Event loop!
This commit is contained in:
332
src/IODriver.hs
332
src/IODriver.hs
@@ -12,7 +12,7 @@ import Research (T(..), apply, toString, toNumber, ofString, ofNumber, ofBytes,
|
||||
import qualified Data.ByteString as BS
|
||||
import System.IO (putStr, getLine)
|
||||
import qualified System.IO as IO
|
||||
import Control.Exception (try, IOException, SomeException)
|
||||
import Control.Exception (try, catch, IOException, SomeException)
|
||||
import System.IO.Error (isDoesNotExistError, isPermissionError, isAlreadyExistsError)
|
||||
import Data.List (isPrefixOf)
|
||||
import System.FilePath (normalise, isRelative, (</>), addTrailingPathSeparator, splitDirectories)
|
||||
@@ -27,6 +27,8 @@ import Control.Concurrent.STM (TVar, newTVarIO, atomically, readTVar, writeTVar,
|
||||
import qualified Data.Set as Set
|
||||
import Data.Set (Set)
|
||||
import qualified Data.Foldable as Fold
|
||||
import qualified Network.Socket as NS
|
||||
import qualified Network.Socket.ByteString as NSB
|
||||
|
||||
-- ---------------------------------------------------------------------------
|
||||
-- Permissions
|
||||
@@ -107,6 +109,9 @@ pureAction x = Fork (ofNumber 0) x
|
||||
invalidAsyncHandleResult :: T
|
||||
invalidAsyncHandleResult = errResult "invalid task handle"
|
||||
|
||||
invalidSocketHandleResult :: T
|
||||
invalidSocketHandleResult = errResult "invalid socket handle"
|
||||
|
||||
selfAwaitResult :: T
|
||||
selfAwaitResult = errResult "self await"
|
||||
|
||||
@@ -145,6 +150,45 @@ decodeTaskHandle tree =
|
||||
_ ->
|
||||
Left "invalid task handle"
|
||||
|
||||
-- ---------------------------------------------------------------------------
|
||||
-- Socket identity and handles
|
||||
-- ---------------------------------------------------------------------------
|
||||
|
||||
newtype SockId = SockId Integer
|
||||
deriving (Eq, Ord, Show)
|
||||
|
||||
sockHandle :: SockId -> T
|
||||
sockHandle (SockId n) =
|
||||
Fork (ofString "sock") (ofNumber n)
|
||||
|
||||
decodeSockHandle :: T -> Either String SockId
|
||||
decodeSockHandle tree =
|
||||
case tree of
|
||||
Fork tag nTree -> do
|
||||
tagString <- toString tag
|
||||
if tagString == "sock"
|
||||
then SockId <$> toNumber nTree
|
||||
else Left "invalid socket handle tag"
|
||||
_ ->
|
||||
Left "invalid socket handle"
|
||||
|
||||
getSocketPort :: NS.Socket -> IO (Maybe Integer)
|
||||
getSocketPort sock = do
|
||||
addr <- NS.getSocketName sock
|
||||
case addr of
|
||||
NS.SockAddrInet p _ -> return (Just (fromIntegral p))
|
||||
NS.SockAddrInet6 p _ _ _ -> return (Just (fromIntegral p))
|
||||
_ -> return Nothing
|
||||
|
||||
-- ---------------------------------------------------------------------------
|
||||
-- Socket registry
|
||||
-- ---------------------------------------------------------------------------
|
||||
|
||||
data SocketRegistry = SocketRegistry
|
||||
{ sockMap :: Map SockId NS.Socket
|
||||
, sockNextId :: Integer
|
||||
}
|
||||
|
||||
-- ---------------------------------------------------------------------------
|
||||
-- Free-monad action AST
|
||||
-- ---------------------------------------------------------------------------
|
||||
@@ -166,6 +210,15 @@ data Action
|
||||
| AAwait T
|
||||
| AYield
|
||||
| ASleep T
|
||||
| ASocket
|
||||
| ACloseSocket T
|
||||
| ABindSocket T T T
|
||||
| AListen T T
|
||||
| AAccept T
|
||||
| AConnect T T T
|
||||
| ARecv T T
|
||||
| ASend T T
|
||||
| AGetSocketName T
|
||||
deriving (Show)
|
||||
|
||||
-- ---------------------------------------------------------------------------
|
||||
@@ -200,6 +253,19 @@ tagAwait = 61
|
||||
tagYield = 62
|
||||
tagSleep = 63
|
||||
|
||||
tagSocket, tagCloseSocket, tagBindSocket, tagListen, tagAccept :: Integer
|
||||
tagSocket = 70
|
||||
tagCloseSocket = 71
|
||||
tagBindSocket = 72
|
||||
tagListen = 73
|
||||
tagAccept = 74
|
||||
|
||||
tagConnect, tagRecv, tagSend, tagGetSocketName :: Integer
|
||||
tagConnect = 75
|
||||
tagRecv = 76
|
||||
tagSend = 77
|
||||
tagGetSocketName = 78
|
||||
|
||||
data Step
|
||||
= Halt Runtime T
|
||||
| Continue Machine
|
||||
@@ -279,6 +345,43 @@ decodeAction tree =
|
||||
Right n | n == tagSleep ->
|
||||
Right (ASleep payload)
|
||||
|
||||
Right n | n == tagSocket ->
|
||||
Right ASocket
|
||||
|
||||
Right n | n == tagCloseSocket ->
|
||||
Right (ACloseSocket payload)
|
||||
|
||||
Right n | n == tagBindSocket ->
|
||||
case payload of
|
||||
Fork sock (Fork addr port) -> Right (ABindSocket sock addr port)
|
||||
_ -> Left "Invalid BindSocket: expected pair sock (pair addr port)"
|
||||
|
||||
Right n | n == tagListen ->
|
||||
case payload of
|
||||
Fork sock backlog -> Right (AListen sock backlog)
|
||||
_ -> Left "Invalid Listen: expected pair sock backlog"
|
||||
|
||||
Right n | n == tagAccept ->
|
||||
Right (AAccept payload)
|
||||
|
||||
Right n | n == tagConnect ->
|
||||
case payload of
|
||||
Fork sock (Fork addr port) -> Right (AConnect sock addr port)
|
||||
_ -> Left "Invalid Connect: expected pair sock (pair addr port)"
|
||||
|
||||
Right n | n == tagRecv ->
|
||||
case payload of
|
||||
Fork sock maxBytes -> Right (ARecv sock maxBytes)
|
||||
_ -> Left "Invalid Recv: expected pair sock maxBytes"
|
||||
|
||||
Right n | n == tagSend ->
|
||||
case payload of
|
||||
Fork sock bytes -> Right (ASend sock bytes)
|
||||
_ -> Left "Invalid Send: expected pair sock bytes"
|
||||
|
||||
Right n | n == tagGetSocketName ->
|
||||
Right (AGetSocketName payload)
|
||||
|
||||
Right n ->
|
||||
Left $ "Unknown IO action tag: " ++ show n
|
||||
|
||||
@@ -312,8 +415,8 @@ finishValue machine value =
|
||||
, machineFrames = rest
|
||||
})
|
||||
|
||||
stepMachine :: Machine -> IO Step
|
||||
stepMachine machine =
|
||||
stepMachine :: TVar SocketRegistry -> Machine -> IO Step
|
||||
stepMachine sockVar machine =
|
||||
case decodeAction (machineCurrent machine) of
|
||||
Right action -> dispatch action
|
||||
Left _ -> finishValue machine (errResult "invalid action")
|
||||
@@ -419,6 +522,207 @@ stepMachine machine =
|
||||
_ ->
|
||||
finishValue machine invalidSleepResult
|
||||
|
||||
ASocket -> do
|
||||
result <- try (NS.socket NS.AF_INET NS.Stream NS.defaultProtocol) :: IO (Either SomeException NS.Socket)
|
||||
case result of
|
||||
Left e ->
|
||||
finishValue machine (errResult ("io error: " ++ show e))
|
||||
Right sock -> do
|
||||
NS.setSocketOption sock NS.ReuseAddr 1
|
||||
sid <- atomically $ do
|
||||
SocketRegistry m next <- readTVar sockVar
|
||||
let sid = SockId next
|
||||
writeTVar sockVar (SocketRegistry (Map.insert sid sock m) (next + 1))
|
||||
return sid
|
||||
finishValue machine (okResult (sockHandle sid))
|
||||
|
||||
ACloseSocket sockTree ->
|
||||
case decodeSockHandle sockTree of
|
||||
Left _ -> finishValue machine invalidSocketHandleResult
|
||||
Right sid -> do
|
||||
mSock <- atomically $ do
|
||||
SocketRegistry m next <- readTVar sockVar
|
||||
case Map.lookup sid m of
|
||||
Nothing -> return Nothing
|
||||
Just sock -> do
|
||||
writeTVar sockVar (SocketRegistry (Map.delete sid m) next)
|
||||
return (Just sock)
|
||||
case mSock of
|
||||
Nothing -> finishValue machine invalidSocketHandleResult
|
||||
Just sock -> do
|
||||
NS.close sock
|
||||
finishValue machine (okResult Leaf)
|
||||
|
||||
ABindSocket sockTree addrTree portTree ->
|
||||
case decodeSockHandle sockTree of
|
||||
Left _ -> finishValue machine invalidSocketHandleResult
|
||||
Right sid ->
|
||||
case decodeString addrTree "BindSocket" of
|
||||
Left _ -> finishValue machine (errResult "invalid address")
|
||||
Right addrStr ->
|
||||
case toNumber portTree of
|
||||
Left _ -> finishValue machine (errResult "invalid port")
|
||||
Right port -> do
|
||||
mSock <- atomically $ do
|
||||
SocketRegistry m _ <- readTVar sockVar
|
||||
return (Map.lookup sid m)
|
||||
case mSock of
|
||||
Nothing -> finishValue machine invalidSocketHandleResult
|
||||
Just sock -> do
|
||||
result <- try (do
|
||||
addrInfo <- NS.getAddrInfo (Just $ NS.defaultHints { NS.addrSocketType = NS.Stream })
|
||||
(Just addrStr)
|
||||
(Just (show port))
|
||||
let serverAddr = head addrInfo
|
||||
NS.bind sock (NS.addrAddress serverAddr)
|
||||
) :: IO (Either SomeException ())
|
||||
case result of
|
||||
Left e ->
|
||||
finishValue machine (errResult ("io error: " ++ show e))
|
||||
Right () ->
|
||||
finishValue machine (okResult Leaf)
|
||||
|
||||
AListen sockTree backlogTree ->
|
||||
case decodeSockHandle sockTree of
|
||||
Left _ -> finishValue machine invalidSocketHandleResult
|
||||
Right sid ->
|
||||
case toNumber backlogTree of
|
||||
Left _ -> finishValue machine (errResult "invalid backlog")
|
||||
Right backlog -> do
|
||||
mSock <- atomically $ do
|
||||
SocketRegistry m _ <- readTVar sockVar
|
||||
return (Map.lookup sid m)
|
||||
case mSock of
|
||||
Nothing -> finishValue machine invalidSocketHandleResult
|
||||
Just sock -> do
|
||||
result <- try (NS.listen sock (fromIntegral backlog)) :: IO (Either SomeException ())
|
||||
case result of
|
||||
Left e ->
|
||||
finishValue machine (errResult ("io error: " ++ show e))
|
||||
Right () ->
|
||||
finishValue machine (okResult Leaf)
|
||||
|
||||
AAccept listenTree ->
|
||||
case decodeSockHandle listenTree of
|
||||
Left _ -> finishValue machine invalidSocketHandleResult
|
||||
Right listenSid ->
|
||||
pure (AsyncAction (do
|
||||
mListenSock <- atomically $ do
|
||||
SocketRegistry m _ <- readTVar sockVar
|
||||
return (Map.lookup listenSid m)
|
||||
case mListenSock of
|
||||
Nothing -> return (errResult "invalid socket handle")
|
||||
Just listenSock -> do
|
||||
result <- try (NS.accept listenSock) :: IO (Either SomeException (NS.Socket, NS.SockAddr))
|
||||
case result of
|
||||
Left e ->
|
||||
return (errResult ("io error: " ++ show e))
|
||||
Right (clientSock, addr) -> do
|
||||
clientSid <- atomically $ do
|
||||
SocketRegistry m next <- readTVar sockVar
|
||||
let sid = SockId next
|
||||
writeTVar sockVar (SocketRegistry (Map.insert sid clientSock m) (next + 1))
|
||||
return sid
|
||||
let addrStr = case addr of
|
||||
NS.SockAddrInet p h ->
|
||||
let (a,b,c,d) = NS.hostAddressToTuple h
|
||||
in show a ++ "." ++ show b ++ "." ++ show c ++ "." ++ show d ++ ":" ++ show p
|
||||
_ -> show addr
|
||||
return (okResult (Fork (sockHandle clientSid) (ofString addrStr)))
|
||||
) machine)
|
||||
|
||||
AConnect sockTree addrTree portTree ->
|
||||
case decodeSockHandle sockTree of
|
||||
Left _ -> finishValue machine invalidSocketHandleResult
|
||||
Right sid ->
|
||||
case decodeString addrTree "Connect" of
|
||||
Left _ -> finishValue machine (errResult "invalid address")
|
||||
Right addrStr ->
|
||||
case toNumber portTree of
|
||||
Left _ -> finishValue machine (errResult "invalid port")
|
||||
Right port -> do
|
||||
mSock <- atomically $ do
|
||||
SocketRegistry m _ <- readTVar sockVar
|
||||
return (Map.lookup sid m)
|
||||
case mSock of
|
||||
Nothing -> finishValue machine invalidSocketHandleResult
|
||||
Just sock ->
|
||||
pure (AsyncAction (do
|
||||
result <- try (do
|
||||
addrInfo <- NS.getAddrInfo (Just $ NS.defaultHints { NS.addrSocketType = NS.Stream })
|
||||
(Just addrStr)
|
||||
(Just (show port))
|
||||
let serverAddr = head addrInfo
|
||||
NS.connect sock (NS.addrAddress serverAddr)
|
||||
) :: IO (Either SomeException ())
|
||||
case result of
|
||||
Left e ->
|
||||
return (errResult ("io error: " ++ show e))
|
||||
Right () ->
|
||||
return (okResult Leaf)
|
||||
) machine)
|
||||
|
||||
ARecv sockTree maxBytesTree ->
|
||||
case decodeSockHandle sockTree of
|
||||
Left _ -> finishValue machine invalidSocketHandleResult
|
||||
Right sid ->
|
||||
case toNumber maxBytesTree of
|
||||
Left _ -> finishValue machine (errResult "invalid maxBytes")
|
||||
Right maxBytes -> do
|
||||
mSock <- atomically $ do
|
||||
SocketRegistry m _ <- readTVar sockVar
|
||||
return (Map.lookup sid m)
|
||||
case mSock of
|
||||
Nothing -> finishValue machine invalidSocketHandleResult
|
||||
Just sock ->
|
||||
pure (AsyncAction (do
|
||||
result <- try (NSB.recv sock (fromIntegral maxBytes)) :: IO (Either SomeException BS.ByteString)
|
||||
case result of
|
||||
Left e ->
|
||||
return (errResult ("io error: " ++ show e))
|
||||
Right bs ->
|
||||
if BS.null bs
|
||||
then return (errResult "connection closed")
|
||||
else return (okResult (ofBytes bs))
|
||||
) machine)
|
||||
|
||||
AGetSocketName sockTree ->
|
||||
case decodeSockHandle sockTree of
|
||||
Left _ -> finishValue machine invalidSocketHandleResult
|
||||
Right sid -> do
|
||||
mSock <- atomically $ do
|
||||
SocketRegistry m _ <- readTVar sockVar
|
||||
return (Map.lookup sid m)
|
||||
case mSock of
|
||||
Nothing -> finishValue machine invalidSocketHandleResult
|
||||
Just sock -> do
|
||||
mPort <- getSocketPort sock
|
||||
case mPort of
|
||||
Just port -> finishValue machine (okResult (ofNumber port))
|
||||
Nothing -> finishValue machine (errResult "io error: could not get socket name")
|
||||
|
||||
ASend sockTree bytesTree ->
|
||||
case decodeSockHandle sockTree of
|
||||
Left _ -> finishValue machine invalidSocketHandleResult
|
||||
Right sid ->
|
||||
case decodeBytes bytesTree "Send" of
|
||||
Left _ -> finishValue machine (errResult "invalid bytes")
|
||||
Right bs -> do
|
||||
mSock <- atomically $ do
|
||||
SocketRegistry m _ <- readTVar sockVar
|
||||
return (Map.lookup sid m)
|
||||
case mSock of
|
||||
Nothing -> finishValue machine invalidSocketHandleResult
|
||||
Just sock ->
|
||||
pure (AsyncAction (do
|
||||
result <- try (NSB.send sock bs) :: IO (Either SomeException Int)
|
||||
case result of
|
||||
Left e ->
|
||||
return (errResult ("io error: " ++ show e))
|
||||
Right sent ->
|
||||
return (okResult (ofNumber (fromIntegral sent)))
|
||||
) machine)
|
||||
|
||||
-- Permission and IO helpers
|
||||
checkReadPerm p =
|
||||
if allowReadAll (rtPerms (machineRuntime machine))
|
||||
@@ -543,6 +847,8 @@ data Scheduler = Scheduler
|
||||
, schedulerSleepQueue :: Map UTCTime (Set TaskId)
|
||||
, schedulerAsyncCompleted :: TVar (Map TaskId T)
|
||||
, schedulerCompleted :: Map TaskId (T, T)
|
||||
, schedulerSockets :: TVar SocketRegistry
|
||||
, schedulerNextSockId :: Integer
|
||||
}
|
||||
|
||||
instance Show Scheduler where
|
||||
@@ -553,10 +859,12 @@ instance Show Scheduler where
|
||||
++ ", schedulerSleepQueue = " ++ show (schedulerSleepQueue s)
|
||||
++ ", schedulerAsyncCompleted = <tvar>"
|
||||
++ ", schedulerCompleted = " ++ show (schedulerCompleted s)
|
||||
++ ", schedulerSockets = <tvar>"
|
||||
++ ", schedulerNextSockId = " ++ show (schedulerNextSockId s)
|
||||
++ " }"
|
||||
|
||||
initialScheduler :: TVar (Map TaskId T) -> Machine -> Scheduler
|
||||
initialScheduler asyncVar mainMachine =
|
||||
initialScheduler :: TVar (Map TaskId T) -> TVar SocketRegistry -> Machine -> Scheduler
|
||||
initialScheduler asyncVar sockVar mainMachine =
|
||||
Scheduler
|
||||
{ schedulerNextTaskId = 1
|
||||
, schedulerRunnable = Seq.singleton (TaskId 0)
|
||||
@@ -565,6 +873,8 @@ initialScheduler asyncVar mainMachine =
|
||||
, schedulerSleepQueue = Map.empty
|
||||
, schedulerAsyncCompleted = asyncVar
|
||||
, schedulerCompleted = Map.empty
|
||||
, schedulerSockets = sockVar
|
||||
, schedulerNextSockId = 0
|
||||
}
|
||||
|
||||
runtimeOfStatus :: TaskStatus -> Maybe Runtime
|
||||
@@ -725,8 +1035,11 @@ handleStep taskId (SleepRequested ms machine) scheduler = do
|
||||
|
||||
handleStep taskId (AsyncAction ioAction machine) scheduler = do
|
||||
_ <- forkIO $ do
|
||||
result <- ioAction
|
||||
atomically $ modifyTVar' (schedulerAsyncCompleted scheduler) (Map.insert taskId result)
|
||||
result <- (Right <$> ioAction) `catch` \(e :: SomeException) -> pure (Left (show e))
|
||||
atomically $ modifyTVar' (schedulerAsyncCompleted scheduler) (Map.insert taskId $
|
||||
case result of
|
||||
Right val -> val
|
||||
Left msg -> errResult msg)
|
||||
pure scheduler
|
||||
{ schedulerTasks = Map.insert taskId (AsyncWaiting machine) (schedulerTasks scheduler)
|
||||
}
|
||||
@@ -784,7 +1097,7 @@ schedulerStep scheduler = do
|
||||
taskId :< restQueue ->
|
||||
case Map.lookup taskId (schedulerTasks scheduler1) of
|
||||
Just (Runnable machine) -> do
|
||||
step <- stepMachine machine
|
||||
step <- stepMachine (schedulerSockets scheduler1) machine
|
||||
handleStep taskId step scheduler1 { schedulerRunnable = restQueue }
|
||||
|
||||
_ ->
|
||||
@@ -809,6 +1122,7 @@ runIOWith perms env initialState action =
|
||||
Left err -> pure (Left err)
|
||||
Right (_, action') -> do
|
||||
asyncVar <- newTVarIO Map.empty
|
||||
sockVar <- newTVarIO (SocketRegistry Map.empty 0)
|
||||
let initialMachine = Machine
|
||||
{ machineRuntime = Runtime
|
||||
{ rtPerms = perms
|
||||
@@ -818,7 +1132,7 @@ runIOWith perms env initialState action =
|
||||
, machineCurrent = action'
|
||||
, machineFrames = []
|
||||
}
|
||||
Right <$> runScheduler (initialScheduler asyncVar initialMachine)
|
||||
Right <$> runScheduler (initialScheduler asyncVar sockVar initialMachine)
|
||||
|
||||
runIOWithEnv :: IOPermissions -> T -> T -> IO (Either String T)
|
||||
runIOWithEnv perms env action = do
|
||||
|
||||
Reference in New Issue
Block a user