{-# LANGUAGE FlexibleContexts #-}

{-# OPTIONS_GHC -fno-warn-missing-signatures #-}
{-# OPTIONS_GHC -fno-warn-unused-top-binds #-}

-- |
-- Module      :  Numeric.GSL.Internal
-- Copyright   :  (c) Alberto Ruiz 2009
-- License     :  GPL
-- Maintainer  :  Alberto Ruiz
-- Stability   :  provisional
--
--
-- Auxiliary functions.
--


module Numeric.GSL.Internal(
    iv,
    mkVecfun,
    mkVecVecfun,
    mkDoubleVecVecfun,
    mkDoublefun,
    aux_vTov,
    mkVecMatfun,
    mkDoubleVecMatfun,
    aux_vTom,
    createV,
    createMIO,
    module Numeric.LinearAlgebra.Devel,
    check,(#),(#!),vec, ww2,
    Res,TV,TM,TCV,TCM
) where

import Numeric.LinearAlgebra.HMatrix
import Numeric.LinearAlgebra.Devel hiding (check)

import Foreign.Marshal.Array(copyArray)
import Foreign.Ptr(Ptr, FunPtr)
import Foreign.C.Types
import Foreign.C.String(peekCString)
import System.IO.Unsafe(unsafePerformIO)
import Data.Vector.Storable as V (unsafeWith,length)
import Control.Monad(when)

-- | Adapt a Haskell function on vectors to the raw @(length, pointer)@
-- calling convention used by the C side.
--
-- The @n@ doubles starting at @p@ are copied into a freshly allocated
-- 'Vector' (so @f@ never holds on to the foreign buffer itself), and @f@
-- is applied to that copy.
iv :: (Vector Double -> Double) -> (CInt -> Ptr Double -> Double)
iv f n p = f snapshot
  where
    snapshot = createV (fromIntegral n) fill "iv"
    -- Copy as many elements as createV requests; a 0 status means success.
    fill count dst = copyArray dst p (fromIntegral count) >> return 0

-- Conversion of Haskell functions into function pointers that can be used
-- on the C side. Each "wrapper" import below turns a Haskell closure into a
-- FunPtr. Per the Haskell FFI, such pointers should eventually be released
-- with freeHaskellFunPtr.

-- | Wrap a function of one raw real vector (length, data pointer) that
-- returns a 'Double'.
foreign import ccall safe "wrapper"
    mkVecfun
        :: (CInt -> Ptr Double -> Double)
        -> IO (FunPtr (CInt -> Ptr Double -> Double))

-- | Wrap a 'TVV' callback: two raw real vectors (length and pointer each)
-- returning a status code.
foreign import ccall safe "wrapper"
    mkVecVecfun
        :: TVV
        -> IO (FunPtr TVV)

-- | Like 'mkVecVecfun', for callbacks with a leading 'Double' argument.
foreign import ccall safe "wrapper"
    mkDoubleVecVecfun
        :: (Double -> TVV)
        -> IO (FunPtr (Double -> TVV))

-- | Wrap a plain scalar function.
foreign import ccall safe "wrapper"
    mkDoublefun
        :: (Double -> Double)
        -> IO (FunPtr (Double -> Double))

-- | Adapt a Haskell vector-to-vector function to the 'TVV' callback shape
-- used by the C side.
--
-- Arguments, in order: @n@ and @p@ (input length and buffer), then @nr@ and
-- @r@ (output length and buffer). The input is copied into a fresh
-- 'Vector', @f@ is applied, and the result is copied into @r@. Returns 0.
--
-- The result of @f@ must have exactly @nr@ elements. Otherwise an error is
-- raised instead of reading past the end of the result vector (which the
-- previous unchecked copy did when the result was too short).
aux_vTov :: (Vector Double -> Vector Double) -> TVV
aux_vTov f n p nr r = g
  where
    v = f x
    x = createV (fromIntegral n) copy "aux_vTov"
    copy n' q = do
        copyArray q p (fromIntegral n')
        return 0
    expected = fromIntegral nr :: Int
    actual = V.length v
    g = do
        -- Guard the raw copy below: copyArray trusts the element count.
        when (actual /= expected) $
            error $ "aux_vTov: function returned a vector of length "
                 ++ show actual ++ ", expected " ++ show expected
        unsafeWith v $ \p' -> copyArray r p' expected
        return 0

-- | Wrap a 'TVM' callback: a raw input vector (length, pointer) and a raw
-- output matrix (rows, columns, pointer), returning a status code.
foreign import ccall safe "wrapper"
    mkVecMatfun
        :: TVM
        -> IO (FunPtr TVM)

-- | Like 'mkVecMatfun', for callbacks with a leading 'Double' argument.
foreign import ccall safe "wrapper"
    mkDoubleVecMatfun
        :: (Double -> TVM)
        -> IO (FunPtr (Double -> TVM))

-- | Adapt a Haskell vector-to-matrix function to the 'TVM' callback shape
-- used by the C side.
--
-- Arguments, in order: @n@ and @p@ (input length and buffer), then @rr@,
-- @cr@ and @r@ (output rows, columns and buffer). The input is copied into
-- a fresh 'Vector', @f@ is applied, and the 'flatten'ed result is copied
-- into @r@. Returns 0.
--
-- The flattened result must contain exactly @rr * cr@ elements. Otherwise
-- an error is raised instead of reading past the end of the result data.
-- Only the element count is checked, not the individual dimensions.
aux_vTom :: (Vector Double -> Matrix Double) -> TVM
aux_vTom f n p rr cr r = g
  where
    v = flatten $ f x
    -- The label passed to createV names this function (it used to say
    -- "aux_vTov", a copy-paste slip).
    x = createV (fromIntegral n) copy "aux_vTom"
    copy n' q = do
        copyArray q p (fromIntegral n')
        return 0
    -- Multiply in Int, not CInt, so large dimensions cannot wrap around.
    expected = fromIntegral rr * fromIntegral cr :: Int
    actual = V.length v
    g = do
        -- Guard the raw copy below: copyArray trusts the element count.
        when (actual /= expected) $
            error $ "aux_vTom: function returned " ++ show actual
                 ++ " elements, expected " ++ show expected
                 ++ " (" ++ show rr ++ "x" ++ show cr ++ ")"
        unsafeWith v $ \p' -> copyArray r p' expected
        return 0

-- | Allocate a vector of @n@ elements and let the raw routine @fun@ fill it.
--
-- @fun@ receives the vector in raw form (a 'CInt', which callers such as
-- 'iv' treat as the element count, and a data pointer). A nonzero status
-- is reported through '#|' using @msg@. The allocation and fill run under
-- 'unsafePerformIO'. This assumes @fun@ only writes into the freshly
-- allocated buffer and has no other observable effects (TODO confirm per
-- caller).
createV n fun msg = unsafePerformIO fill
  where
    fill = do
        buf <- createVector n
        applyRaw buf id fun #| msg
        return buf

-- | Allocate an @nr@ x @nc@ 'RowMajor' matrix and let the raw routine @fun@
-- fill it (rows, columns, data pointer). A nonzero status is reported
-- through '#|' using @msg@. Returns the filled matrix in 'IO'.
createMIO nr nc fun msg = do
    mat <- createMatrix RowMajor nr nc
    applyRaw mat id fun #| msg
    return mat

--------------------------------------------------------------------------------

-- | Run an action that returns a GSL status code and raise an error when
-- the code is nonzero.
--
-- The error message is @msg ++ ": " ++ description@, where the description
-- is GSL's text for the code as returned by 'gsl_strerror'.
check :: String -> IO CInt -> IO ()
check msg f = f >>= report
  where
    report 0 = return ()
    report code = do
        description <- gsl_strerror code >>= peekCString
        error (msg ++ ": " ++ description)

-- | GSL's textual description of an error code (the C function
-- @gsl_strerror@). The result is a C string; 'check' reads it with
-- 'peekCString'.
foreign import ccall unsafe "gsl_strerror"
    gsl_strerror
        :: CInt
        -> IO (Ptr CChar)

-- | Result of a raw routine or callback: an 'IO' action yielding a status
-- code ('check' treats any nonzero value as an error).
type Res = IO CInt

-- Raw data pointers for each element type.
type PF = Ptr Float
type PD = Ptr Double
type PQ = Ptr (Complex Float)
type PC = Ptr (Complex Double)

-- | A real vector argument in raw form: element count, then data pointer.
type TV x  = CInt -> PD -> x

-- | A real matrix argument in raw form: rows, columns, then data pointer.
type TM x  = CInt -> CInt -> PD -> x

-- | A complex vector argument in raw form: element count, then data pointer.
type TCV x = CInt -> PC -> x

-- | A complex matrix argument in raw form: rows, columns, then data pointer.
type TCM x = CInt -> CInt -> PC -> x

-- | Callback shape with a raw input vector and a raw output vector.
type TVV = TV (TV Res)

-- | Callback shape with a raw input vector and a raw output matrix.
type TVM = TV (TM Res)

-- | Nest two continuation-passing wrappers (for example 'vec'). @w1@ runs on
-- @o1@, @w2@ runs on @o2@ inside it, and @f@ receives both results.
--
-- > ww2 w1 o1 w2 o2 f == w1 o1 (\a1 -> w2 o2 (\a2 -> f a1 a2))
ww2 :: (t1 -> (a1 -> r1) -> r)
    -> t1
    -> (t2 -> (a2 -> r2) -> r1)
    -> t2
    -> (a1 -> a2 -> r2)
    -> r
ww2 w1 o1 w2 o2 f = w1 o1 outer
  where
    outer a1 = w2 o2 (inner a1)
    inner a1 a2 = f a1 a2

-- | Expose a vector to a continuation in raw @(length, pointer)@ form.
--
-- @f@ receives a function that, given a consumer @g@, calls
-- @g (length of x) (pointer to x's data)@. The pointer comes from
-- 'unsafeWith', so it is only valid while @f@ runs.
vec x f = unsafeWith x $ \ptr -> f (\g -> g (fi (V.length x)) ptr)
{-# INLINE vec #-}

infixl 1 #

-- | Left-associative (fixity 1) infix synonym for 'applyRaw'. For example,
-- @(v # id) fun@ hands @v@ in raw form (for a vector: its size and data
-- pointer) to the routine @fun@, as 'createV' does.
arr # k = arr `applyRaw` k
{-# INLINE (#) #-}

--infixr 1 #
--a # b = apply a b
--{-# INLINE (#) #-}

-- | @a #! k@ means @a # k # id@: apply '#' with @k@, then apply '#' once
-- more to the result with 'id'.
arr #! k = (arr # k) # id
{-# INLINE (#!) #-}