{-# LANGUAGE Trustworthy #-}
{-# LANGUAGE NoImplicitPrelude, ScopedTypeVariables #-}
module Foreign.Marshal.Pool (
Pool,
newPool,
freePool,
withPool,
pooledMalloc,
pooledMallocBytes,
pooledRealloc,
pooledReallocBytes,
pooledMallocArray,
pooledMallocArray0,
pooledReallocArray,
pooledReallocArray0,
pooledNew,
pooledNewArray,
pooledNewArray0
) where
import GHC.Base ( Int, Monad(..), (.), liftM, not )
import GHC.Err ( undefined )
import GHC.Exception ( throw )
import GHC.IO ( IO, mask, catchAny )
import GHC.IORef ( IORef, newIORef, readIORef, writeIORef )
import GHC.List ( elem, length )
import GHC.Num ( Num(..) )
import Data.OldList ( delete )
import Foreign.Marshal.Alloc ( mallocBytes, reallocBytes, free )
import Foreign.Marshal.Array ( pokeArray, pokeArray0 )
import Foreign.Marshal.Error ( throwIf )
import Foreign.Ptr ( Ptr, castPtr )
import Foreign.Storable ( Storable(sizeOf, poke) )
-- | A memory pool: a mutable list of every block allocated through it
-- (stored as untyped @'Ptr' ()@), so that all of them can be released
-- together by 'freePool'.
newtype Pool = Pool (IORef [Ptr ()])
-- | Create a fresh pool that does not yet own any memory.
newPool :: IO Pool
newPool = do
  ref <- newIORef []
  return (Pool ref)
-- | Deallocate every block that was allocated in the pool and leave the
-- pool empty.
--
-- The pool's list is cleared /before/ the blocks are passed to 'free'.
-- Calling 'freePool' a second time on the same pool is therefore harmless.
-- This can happen when it is called manually inside 'withPool', which
-- frees the pool again on exit. Previously the second call would pass the
-- same pointers to 'free' twice. Once cleared, the pool no longer
-- recognises the freed pointers as its own.
freePool :: Pool -> IO ()
freePool (Pool pool) = do
  ptrs <- readIORef pool
  -- Drop ownership first so no later call can see the soon-dangling pointers.
  writeIORef pool []
  freeAll ptrs
  where
    freeAll [] = return ()
    freeAll (p:ps) = free p >> freeAll ps
-- | Run an action with a freshly created pool and free the pool afterwards.
--
-- Asynchronous exceptions are masked except while the user action runs
-- (via the @restore@ function from 'mask'). If the action throws, the pool
-- is freed and the exception is re-thrown. Otherwise the pool is freed and
-- the action's result is returned.
withPool :: (Pool -> IO b) -> IO b
withPool act =
  mask (\restore -> do
    pool <- newPool
    result <- restore (act pool) `catchAny` \e -> do
      freePool pool
      throw e
    freePool pool
    return result)
-- | Allocate, in the given pool, enough memory for one value of type @a@
-- (as reported by 'sizeOf').
pooledMalloc :: forall a . Storable a => Pool -> IO (Ptr a)
pooledMalloc pool = pooledMallocBytes pool bytes
  where
    bytes = sizeOf (undefined :: a)
-- | Allocate @size@ bytes with 'mallocBytes' and record the new block in
-- the pool so that 'freePool' will release it.
pooledMallocBytes :: Pool -> Int -> IO (Ptr a)
pooledMallocBytes (Pool ref) size = do
  raw <- mallocBytes size
  readIORef ref >>= \owned -> writeIORef ref (raw : owned)
  return (castPtr raw)
-- | Resize a pool-owned block to hold exactly one value of type @a@
-- (as reported by 'sizeOf').
pooledRealloc :: forall a . Storable a => Pool -> Ptr a -> IO (Ptr a)
pooledRealloc pool ptr = pooledReallocBytes pool ptr bytes
  where
    bytes = sizeOf (undefined :: a)
-- | Resize a block that belongs to the given pool to @size@ bytes.
--
-- If @ptr@ is not currently recorded in the pool, 'throwIf' raises an
-- 'IOError' with the message @"pointer not in pool"@. Otherwise the block
-- is resized with 'reallocBytes'. The pool's entry for @ptr@ is replaced by
-- the returned pointer, which may differ from @ptr@.
pooledReallocBytes :: Pool -> Ptr a -> Int -> IO (Ptr a)
pooledReallocBytes (Pool ref) ptr size = do
  let old = castPtr ptr :: Ptr ()
      isTracked = elem old
  _ <- throwIf (not . isTracked) (\_ -> "pointer not in pool") (readIORef ref)
  moved <- reallocBytes old size
  current <- readIORef ref
  -- Swap the stale entry for the (possibly relocated) block.
  writeIORef ref (moved : delete old current)
  return (castPtr moved)
-- | Allocate, in the given pool, room for @size@ consecutive values of
-- type @a@.
pooledMallocArray :: forall a . Storable a => Pool -> Int -> IO (Ptr a)
pooledMallocArray pool size = pooledMallocBytes pool totalBytes
  where
    totalBytes = size * sizeOf (undefined :: a)
-- | Like 'pooledMallocArray', but reserves one extra element, e.g. for a
-- terminating marker.
pooledMallocArray0 :: Storable a => Pool -> Int -> IO (Ptr a)
pooledMallocArray0 pool size = pooledMallocArray pool withTerminator
  where
    withTerminator = size + 1
-- | Resize a pool-owned block to hold @size@ consecutive values of type @a@.
pooledReallocArray :: forall a . Storable a => Pool -> Ptr a -> Int -> IO (Ptr a)
pooledReallocArray pool ptr size = pooledReallocBytes pool ptr totalBytes
  where
    totalBytes = size * sizeOf (undefined :: a)
-- | Like 'pooledReallocArray', but reserves one extra element, e.g. for a
-- terminating marker.
pooledReallocArray0 :: Storable a => Pool -> Ptr a -> Int -> IO (Ptr a)
pooledReallocArray0 pool ptr size = pooledReallocArray pool ptr withTerminator
  where
    withTerminator = size + 1
-- | Allocate space for one value in the pool, store @val@ there with
-- 'poke', and return the pointer.
pooledNew :: Storable a => Pool -> a -> IO (Ptr a)
pooledNew pool val =
  pooledMalloc pool >>= \cell ->
    poke cell val >> return cell
-- | Allocate an array in the pool sized to the given list, copy the list's
-- elements into it with 'pokeArray', and return the pointer.
pooledNewArray :: Storable a => Pool -> [a] -> IO (Ptr a)
pooledNewArray pool vals = do
  let count = length vals
  buf <- pooledMallocArray pool count
  pokeArray buf vals
  return buf
-- | Allocate an array in the pool with one slot more than the list's
-- length. The elements are written followed by @marker@ (via 'pokeArray0'),
-- and the pointer is returned.
pooledNewArray0 :: Storable a => Pool -> a -> [a] -> IO (Ptr a)
pooledNewArray0 pool marker vals = do
  let count = length vals
  buf <- pooledMallocArray0 pool count
  pokeArray0 marker buf vals
  return buf