Event loop!
This commit is contained in:
46
demos/interactionTrees/echo-server.tri
Normal file
46
demos/interactionTrees/echo-server.tri
Normal file
@@ -0,0 +1,46 @@
|
||||
!import "../../lib/base.tri" !Local
|
||||
!import "../../lib/io.tri" !Local
|
||||
!import "../../lib/socket.tri" !Local
|
||||
|
||||
-- Preserve the host-driver Result shape on error, run okCase on success.
|
||||
onOk = action okCase :
|
||||
bind action (result :
|
||||
matchResult
|
||||
(err rest : pure result)
|
||||
okCase
|
||||
result)
|
||||
|
||||
-- Convenience: print a string and continue.
|
||||
printLn = s : bind (putStr (append s "\n")) (_ : pure t)
|
||||
|
||||
-- Main accept+echo loop. Recursion via y.
|
||||
echoLoop = y (self server :
|
||||
bind (accept server) (acceptResult :
|
||||
matchResult
|
||||
(err rest :
|
||||
bind (printLn (append "accept error: " err)) (_ :
|
||||
self server))
|
||||
(accepted rest :
|
||||
matchPair
|
||||
(clientSock addr :
|
||||
bind (printLn (append "client from " addr)) (_ :
|
||||
bind (recv clientSock 4096) (msgResult :
|
||||
matchResult
|
||||
(err rest :
|
||||
bind (closeSocket clientSock) (_ :
|
||||
self server))
|
||||
(msg rest :
|
||||
bind (send clientSock msg) (_ :
|
||||
bind (closeSocket clientSock) (_ :
|
||||
self server)))
|
||||
msgResult)))
|
||||
accepted)
|
||||
acceptResult))
|
||||
|
||||
main = io (
|
||||
onOk socket (server rest :
|
||||
onOk (bindSocket server "127.0.0.1" 0) (_ rest :
|
||||
onOk (listen server 5) (_ rest :
|
||||
onOk (getSocketName server) (port rest :
|
||||
bind (printLn (append "Echo server listening on port " (showNumber port))) (_ :
|
||||
echoLoop server))))))
|
||||
63
lib/socket.tri
Normal file
63
lib/socket.tri
Normal file
@@ -0,0 +1,63 @@
|
||||
!import "base.tri" !Local
|
||||
!import "io.tri" !Local
|
||||
|
||||
-- Socket primitives for the IO driver.
|
||||
-- All actions return a Result tree (see lib/base.tri):
|
||||
-- ok value -- pair true (pair value t)
|
||||
-- err msg -- pair false (pair msg t)
|
||||
|
||||
socket = pair 70 t
|
||||
closeSocket = sock : pair 71 sock
|
||||
bindSocket = sock addr port : pair 72 (pair sock (pair addr port))
|
||||
listen = sock backlog : pair 73 (pair sock backlog)
|
||||
accept = sock : pair 74 sock
|
||||
connect = sock addr port : pair 75 (pair sock (pair addr port))
|
||||
recv = sock maxBytes : pair 76 (pair sock maxBytes)
|
||||
send = sock bytes : pair 77 (pair sock bytes)
|
||||
getSocketName = sock : pair 78 sock
|
||||
|
||||
-- ---------------------------------------------------------------------------
|
||||
-- Convenience helpers
|
||||
-- ---------------------------------------------------------------------------
|
||||
|
||||
onSocket = (action errCase okCase :
|
||||
bind action (result :
|
||||
matchResult errCase okCase result))
|
||||
|
||||
-- Create a listening socket bound to an address and port.
|
||||
-- Returns ok listenSocket or err message.
|
||||
listenSocket = addr port backlog :
|
||||
bind (socket) (result :
|
||||
matchResult
|
||||
(err rest : pure (err "socket creation failed"))
|
||||
(sock rest :
|
||||
bind (bindSocket sock addr port) (bindResult :
|
||||
matchResult
|
||||
(err rest : pure (err "bind failed"))
|
||||
(_ rest :
|
||||
bind (listen sock backlog) (listenResult :
|
||||
matchResult
|
||||
(err rest : pure (err "listen failed"))
|
||||
(_ rest : pure (ok sock))
|
||||
listenResult))
|
||||
bindResult))
|
||||
result)
|
||||
|
||||
-- Accept a connection and return (clientSocket, peerAddr).
|
||||
-- The returned peerAddr is a string like "127.0.0.1:8080".
|
||||
onAccept = (sock errCase okCase :
|
||||
bind (accept sock) (result :
|
||||
matchResult errCase okCase result))
|
||||
|
||||
-- Receive all available bytes up to maxBytes.
|
||||
onRecv = (sock maxBytes errCase okCase :
|
||||
bind (recv sock maxBytes) (result :
|
||||
matchResult errCase okCase result))
|
||||
|
||||
-- Send bytes and return number of bytes sent.
|
||||
onSend = (sock bytes errCase okCase :
|
||||
bind (send sock bytes) (result :
|
||||
matchResult errCase okCase result))
|
||||
|
||||
-- Close a socket, ignoring errors.
|
||||
closeSocket_ = sock : bind (closeSocket sock) (_ : pure t)
|
||||
332
src/IODriver.hs
332
src/IODriver.hs
@@ -12,7 +12,7 @@ import Research (T(..), apply, toString, toNumber, ofString, ofNumber, ofBytes,
|
||||
import qualified Data.ByteString as BS
|
||||
import System.IO (putStr, getLine)
|
||||
import qualified System.IO as IO
|
||||
import Control.Exception (try, IOException, SomeException)
|
||||
import Control.Exception (try, catch, IOException, SomeException)
|
||||
import System.IO.Error (isDoesNotExistError, isPermissionError, isAlreadyExistsError)
|
||||
import Data.List (isPrefixOf)
|
||||
import System.FilePath (normalise, isRelative, (</>), addTrailingPathSeparator, splitDirectories)
|
||||
@@ -27,6 +27,8 @@ import Control.Concurrent.STM (TVar, newTVarIO, atomically, readTVar, writeTVar,
|
||||
import qualified Data.Set as Set
|
||||
import Data.Set (Set)
|
||||
import qualified Data.Foldable as Fold
|
||||
import qualified Network.Socket as NS
|
||||
import qualified Network.Socket.ByteString as NSB
|
||||
|
||||
-- ---------------------------------------------------------------------------
|
||||
-- Permissions
|
||||
@@ -107,6 +109,9 @@ pureAction x = Fork (ofNumber 0) x
|
||||
invalidAsyncHandleResult :: T
|
||||
invalidAsyncHandleResult = errResult "invalid task handle"
|
||||
|
||||
invalidSocketHandleResult :: T
|
||||
invalidSocketHandleResult = errResult "invalid socket handle"
|
||||
|
||||
selfAwaitResult :: T
|
||||
selfAwaitResult = errResult "self await"
|
||||
|
||||
@@ -145,6 +150,45 @@ decodeTaskHandle tree =
|
||||
_ ->
|
||||
Left "invalid task handle"
|
||||
|
||||
-- ---------------------------------------------------------------------------
|
||||
-- Socket identity and handles
|
||||
-- ---------------------------------------------------------------------------
|
||||
|
||||
newtype SockId = SockId Integer
|
||||
deriving (Eq, Ord, Show)
|
||||
|
||||
sockHandle :: SockId -> T
|
||||
sockHandle (SockId n) =
|
||||
Fork (ofString "sock") (ofNumber n)
|
||||
|
||||
decodeSockHandle :: T -> Either String SockId
|
||||
decodeSockHandle tree =
|
||||
case tree of
|
||||
Fork tag nTree -> do
|
||||
tagString <- toString tag
|
||||
if tagString == "sock"
|
||||
then SockId <$> toNumber nTree
|
||||
else Left "invalid socket handle tag"
|
||||
_ ->
|
||||
Left "invalid socket handle"
|
||||
|
||||
getSocketPort :: NS.Socket -> IO (Maybe Integer)
|
||||
getSocketPort sock = do
|
||||
addr <- NS.getSocketName sock
|
||||
case addr of
|
||||
NS.SockAddrInet p _ -> return (Just (fromIntegral p))
|
||||
NS.SockAddrInet6 p _ _ _ -> return (Just (fromIntegral p))
|
||||
_ -> return Nothing
|
||||
|
||||
-- ---------------------------------------------------------------------------
|
||||
-- Socket registry
|
||||
-- ---------------------------------------------------------------------------
|
||||
|
||||
data SocketRegistry = SocketRegistry
|
||||
{ sockMap :: Map SockId NS.Socket
|
||||
, sockNextId :: Integer
|
||||
}
|
||||
|
||||
-- ---------------------------------------------------------------------------
|
||||
-- Free-monad action AST
|
||||
-- ---------------------------------------------------------------------------
|
||||
@@ -166,6 +210,15 @@ data Action
|
||||
| AAwait T
|
||||
| AYield
|
||||
| ASleep T
|
||||
| ASocket
|
||||
| ACloseSocket T
|
||||
| ABindSocket T T T
|
||||
| AListen T T
|
||||
| AAccept T
|
||||
| AConnect T T T
|
||||
| ARecv T T
|
||||
| ASend T T
|
||||
| AGetSocketName T
|
||||
deriving (Show)
|
||||
|
||||
-- ---------------------------------------------------------------------------
|
||||
@@ -200,6 +253,19 @@ tagAwait = 61
|
||||
tagYield = 62
|
||||
tagSleep = 63
|
||||
|
||||
tagSocket, tagCloseSocket, tagBindSocket, tagListen, tagAccept :: Integer
|
||||
tagSocket = 70
|
||||
tagCloseSocket = 71
|
||||
tagBindSocket = 72
|
||||
tagListen = 73
|
||||
tagAccept = 74
|
||||
|
||||
tagConnect, tagRecv, tagSend, tagGetSocketName :: Integer
|
||||
tagConnect = 75
|
||||
tagRecv = 76
|
||||
tagSend = 77
|
||||
tagGetSocketName = 78
|
||||
|
||||
data Step
|
||||
= Halt Runtime T
|
||||
| Continue Machine
|
||||
@@ -279,6 +345,43 @@ decodeAction tree =
|
||||
Right n | n == tagSleep ->
|
||||
Right (ASleep payload)
|
||||
|
||||
Right n | n == tagSocket ->
|
||||
Right ASocket
|
||||
|
||||
Right n | n == tagCloseSocket ->
|
||||
Right (ACloseSocket payload)
|
||||
|
||||
Right n | n == tagBindSocket ->
|
||||
case payload of
|
||||
Fork sock (Fork addr port) -> Right (ABindSocket sock addr port)
|
||||
_ -> Left "Invalid BindSocket: expected pair sock (pair addr port)"
|
||||
|
||||
Right n | n == tagListen ->
|
||||
case payload of
|
||||
Fork sock backlog -> Right (AListen sock backlog)
|
||||
_ -> Left "Invalid Listen: expected pair sock backlog"
|
||||
|
||||
Right n | n == tagAccept ->
|
||||
Right (AAccept payload)
|
||||
|
||||
Right n | n == tagConnect ->
|
||||
case payload of
|
||||
Fork sock (Fork addr port) -> Right (AConnect sock addr port)
|
||||
_ -> Left "Invalid Connect: expected pair sock (pair addr port)"
|
||||
|
||||
Right n | n == tagRecv ->
|
||||
case payload of
|
||||
Fork sock maxBytes -> Right (ARecv sock maxBytes)
|
||||
_ -> Left "Invalid Recv: expected pair sock maxBytes"
|
||||
|
||||
Right n | n == tagSend ->
|
||||
case payload of
|
||||
Fork sock bytes -> Right (ASend sock bytes)
|
||||
_ -> Left "Invalid Send: expected pair sock bytes"
|
||||
|
||||
Right n | n == tagGetSocketName ->
|
||||
Right (AGetSocketName payload)
|
||||
|
||||
Right n ->
|
||||
Left $ "Unknown IO action tag: " ++ show n
|
||||
|
||||
@@ -312,8 +415,8 @@ finishValue machine value =
|
||||
, machineFrames = rest
|
||||
})
|
||||
|
||||
stepMachine :: Machine -> IO Step
|
||||
stepMachine machine =
|
||||
stepMachine :: TVar SocketRegistry -> Machine -> IO Step
|
||||
stepMachine sockVar machine =
|
||||
case decodeAction (machineCurrent machine) of
|
||||
Right action -> dispatch action
|
||||
Left _ -> finishValue machine (errResult "invalid action")
|
||||
@@ -419,6 +522,207 @@ stepMachine machine =
|
||||
_ ->
|
||||
finishValue machine invalidSleepResult
|
||||
|
||||
ASocket -> do
|
||||
result <- try (NS.socket NS.AF_INET NS.Stream NS.defaultProtocol) :: IO (Either SomeException NS.Socket)
|
||||
case result of
|
||||
Left e ->
|
||||
finishValue machine (errResult ("io error: " ++ show e))
|
||||
Right sock -> do
|
||||
NS.setSocketOption sock NS.ReuseAddr 1
|
||||
sid <- atomically $ do
|
||||
SocketRegistry m next <- readTVar sockVar
|
||||
let sid = SockId next
|
||||
writeTVar sockVar (SocketRegistry (Map.insert sid sock m) (next + 1))
|
||||
return sid
|
||||
finishValue machine (okResult (sockHandle sid))
|
||||
|
||||
ACloseSocket sockTree ->
|
||||
case decodeSockHandle sockTree of
|
||||
Left _ -> finishValue machine invalidSocketHandleResult
|
||||
Right sid -> do
|
||||
mSock <- atomically $ do
|
||||
SocketRegistry m next <- readTVar sockVar
|
||||
case Map.lookup sid m of
|
||||
Nothing -> return Nothing
|
||||
Just sock -> do
|
||||
writeTVar sockVar (SocketRegistry (Map.delete sid m) next)
|
||||
return (Just sock)
|
||||
case mSock of
|
||||
Nothing -> finishValue machine invalidSocketHandleResult
|
||||
Just sock -> do
|
||||
NS.close sock
|
||||
finishValue machine (okResult Leaf)
|
||||
|
||||
ABindSocket sockTree addrTree portTree ->
|
||||
case decodeSockHandle sockTree of
|
||||
Left _ -> finishValue machine invalidSocketHandleResult
|
||||
Right sid ->
|
||||
case decodeString addrTree "BindSocket" of
|
||||
Left _ -> finishValue machine (errResult "invalid address")
|
||||
Right addrStr ->
|
||||
case toNumber portTree of
|
||||
Left _ -> finishValue machine (errResult "invalid port")
|
||||
Right port -> do
|
||||
mSock <- atomically $ do
|
||||
SocketRegistry m _ <- readTVar sockVar
|
||||
return (Map.lookup sid m)
|
||||
case mSock of
|
||||
Nothing -> finishValue machine invalidSocketHandleResult
|
||||
Just sock -> do
|
||||
result <- try (do
|
||||
addrInfo <- NS.getAddrInfo (Just $ NS.defaultHints { NS.addrSocketType = NS.Stream })
|
||||
(Just addrStr)
|
||||
(Just (show port))
|
||||
let serverAddr = head addrInfo
|
||||
NS.bind sock (NS.addrAddress serverAddr)
|
||||
) :: IO (Either SomeException ())
|
||||
case result of
|
||||
Left e ->
|
||||
finishValue machine (errResult ("io error: " ++ show e))
|
||||
Right () ->
|
||||
finishValue machine (okResult Leaf)
|
||||
|
||||
AListen sockTree backlogTree ->
|
||||
case decodeSockHandle sockTree of
|
||||
Left _ -> finishValue machine invalidSocketHandleResult
|
||||
Right sid ->
|
||||
case toNumber backlogTree of
|
||||
Left _ -> finishValue machine (errResult "invalid backlog")
|
||||
Right backlog -> do
|
||||
mSock <- atomically $ do
|
||||
SocketRegistry m _ <- readTVar sockVar
|
||||
return (Map.lookup sid m)
|
||||
case mSock of
|
||||
Nothing -> finishValue machine invalidSocketHandleResult
|
||||
Just sock -> do
|
||||
result <- try (NS.listen sock (fromIntegral backlog)) :: IO (Either SomeException ())
|
||||
case result of
|
||||
Left e ->
|
||||
finishValue machine (errResult ("io error: " ++ show e))
|
||||
Right () ->
|
||||
finishValue machine (okResult Leaf)
|
||||
|
||||
AAccept listenTree ->
|
||||
case decodeSockHandle listenTree of
|
||||
Left _ -> finishValue machine invalidSocketHandleResult
|
||||
Right listenSid ->
|
||||
pure (AsyncAction (do
|
||||
mListenSock <- atomically $ do
|
||||
SocketRegistry m _ <- readTVar sockVar
|
||||
return (Map.lookup listenSid m)
|
||||
case mListenSock of
|
||||
Nothing -> return (errResult "invalid socket handle")
|
||||
Just listenSock -> do
|
||||
result <- try (NS.accept listenSock) :: IO (Either SomeException (NS.Socket, NS.SockAddr))
|
||||
case result of
|
||||
Left e ->
|
||||
return (errResult ("io error: " ++ show e))
|
||||
Right (clientSock, addr) -> do
|
||||
clientSid <- atomically $ do
|
||||
SocketRegistry m next <- readTVar sockVar
|
||||
let sid = SockId next
|
||||
writeTVar sockVar (SocketRegistry (Map.insert sid clientSock m) (next + 1))
|
||||
return sid
|
||||
let addrStr = case addr of
|
||||
NS.SockAddrInet p h ->
|
||||
let (a,b,c,d) = NS.hostAddressToTuple h
|
||||
in show a ++ "." ++ show b ++ "." ++ show c ++ "." ++ show d ++ ":" ++ show p
|
||||
_ -> show addr
|
||||
return (okResult (Fork (sockHandle clientSid) (ofString addrStr)))
|
||||
) machine)
|
||||
|
||||
AConnect sockTree addrTree portTree ->
|
||||
case decodeSockHandle sockTree of
|
||||
Left _ -> finishValue machine invalidSocketHandleResult
|
||||
Right sid ->
|
||||
case decodeString addrTree "Connect" of
|
||||
Left _ -> finishValue machine (errResult "invalid address")
|
||||
Right addrStr ->
|
||||
case toNumber portTree of
|
||||
Left _ -> finishValue machine (errResult "invalid port")
|
||||
Right port -> do
|
||||
mSock <- atomically $ do
|
||||
SocketRegistry m _ <- readTVar sockVar
|
||||
return (Map.lookup sid m)
|
||||
case mSock of
|
||||
Nothing -> finishValue machine invalidSocketHandleResult
|
||||
Just sock ->
|
||||
pure (AsyncAction (do
|
||||
result <- try (do
|
||||
addrInfo <- NS.getAddrInfo (Just $ NS.defaultHints { NS.addrSocketType = NS.Stream })
|
||||
(Just addrStr)
|
||||
(Just (show port))
|
||||
let serverAddr = head addrInfo
|
||||
NS.connect sock (NS.addrAddress serverAddr)
|
||||
) :: IO (Either SomeException ())
|
||||
case result of
|
||||
Left e ->
|
||||
return (errResult ("io error: " ++ show e))
|
||||
Right () ->
|
||||
return (okResult Leaf)
|
||||
) machine)
|
||||
|
||||
ARecv sockTree maxBytesTree ->
|
||||
case decodeSockHandle sockTree of
|
||||
Left _ -> finishValue machine invalidSocketHandleResult
|
||||
Right sid ->
|
||||
case toNumber maxBytesTree of
|
||||
Left _ -> finishValue machine (errResult "invalid maxBytes")
|
||||
Right maxBytes -> do
|
||||
mSock <- atomically $ do
|
||||
SocketRegistry m _ <- readTVar sockVar
|
||||
return (Map.lookup sid m)
|
||||
case mSock of
|
||||
Nothing -> finishValue machine invalidSocketHandleResult
|
||||
Just sock ->
|
||||
pure (AsyncAction (do
|
||||
result <- try (NSB.recv sock (fromIntegral maxBytes)) :: IO (Either SomeException BS.ByteString)
|
||||
case result of
|
||||
Left e ->
|
||||
return (errResult ("io error: " ++ show e))
|
||||
Right bs ->
|
||||
if BS.null bs
|
||||
then return (errResult "connection closed")
|
||||
else return (okResult (ofBytes bs))
|
||||
) machine)
|
||||
|
||||
AGetSocketName sockTree ->
|
||||
case decodeSockHandle sockTree of
|
||||
Left _ -> finishValue machine invalidSocketHandleResult
|
||||
Right sid -> do
|
||||
mSock <- atomically $ do
|
||||
SocketRegistry m _ <- readTVar sockVar
|
||||
return (Map.lookup sid m)
|
||||
case mSock of
|
||||
Nothing -> finishValue machine invalidSocketHandleResult
|
||||
Just sock -> do
|
||||
mPort <- getSocketPort sock
|
||||
case mPort of
|
||||
Just port -> finishValue machine (okResult (ofNumber port))
|
||||
Nothing -> finishValue machine (errResult "io error: could not get socket name")
|
||||
|
||||
ASend sockTree bytesTree ->
|
||||
case decodeSockHandle sockTree of
|
||||
Left _ -> finishValue machine invalidSocketHandleResult
|
||||
Right sid ->
|
||||
case decodeBytes bytesTree "Send" of
|
||||
Left _ -> finishValue machine (errResult "invalid bytes")
|
||||
Right bs -> do
|
||||
mSock <- atomically $ do
|
||||
SocketRegistry m _ <- readTVar sockVar
|
||||
return (Map.lookup sid m)
|
||||
case mSock of
|
||||
Nothing -> finishValue machine invalidSocketHandleResult
|
||||
Just sock ->
|
||||
pure (AsyncAction (do
|
||||
result <- try (NSB.send sock bs) :: IO (Either SomeException Int)
|
||||
case result of
|
||||
Left e ->
|
||||
return (errResult ("io error: " ++ show e))
|
||||
Right sent ->
|
||||
return (okResult (ofNumber (fromIntegral sent)))
|
||||
) machine)
|
||||
|
||||
-- Permission and IO helpers
|
||||
checkReadPerm p =
|
||||
if allowReadAll (rtPerms (machineRuntime machine))
|
||||
@@ -543,6 +847,8 @@ data Scheduler = Scheduler
|
||||
, schedulerSleepQueue :: Map UTCTime (Set TaskId)
|
||||
, schedulerAsyncCompleted :: TVar (Map TaskId T)
|
||||
, schedulerCompleted :: Map TaskId (T, T)
|
||||
, schedulerSockets :: TVar SocketRegistry
|
||||
, schedulerNextSockId :: Integer
|
||||
}
|
||||
|
||||
instance Show Scheduler where
|
||||
@@ -553,10 +859,12 @@ instance Show Scheduler where
|
||||
++ ", schedulerSleepQueue = " ++ show (schedulerSleepQueue s)
|
||||
++ ", schedulerAsyncCompleted = <tvar>"
|
||||
++ ", schedulerCompleted = " ++ show (schedulerCompleted s)
|
||||
++ ", schedulerSockets = <tvar>"
|
||||
++ ", schedulerNextSockId = " ++ show (schedulerNextSockId s)
|
||||
++ " }"
|
||||
|
||||
initialScheduler :: TVar (Map TaskId T) -> Machine -> Scheduler
|
||||
initialScheduler asyncVar mainMachine =
|
||||
initialScheduler :: TVar (Map TaskId T) -> TVar SocketRegistry -> Machine -> Scheduler
|
||||
initialScheduler asyncVar sockVar mainMachine =
|
||||
Scheduler
|
||||
{ schedulerNextTaskId = 1
|
||||
, schedulerRunnable = Seq.singleton (TaskId 0)
|
||||
@@ -565,6 +873,8 @@ initialScheduler asyncVar mainMachine =
|
||||
, schedulerSleepQueue = Map.empty
|
||||
, schedulerAsyncCompleted = asyncVar
|
||||
, schedulerCompleted = Map.empty
|
||||
, schedulerSockets = sockVar
|
||||
, schedulerNextSockId = 0
|
||||
}
|
||||
|
||||
runtimeOfStatus :: TaskStatus -> Maybe Runtime
|
||||
@@ -725,8 +1035,11 @@ handleStep taskId (SleepRequested ms machine) scheduler = do
|
||||
|
||||
handleStep taskId (AsyncAction ioAction machine) scheduler = do
|
||||
_ <- forkIO $ do
|
||||
result <- ioAction
|
||||
atomically $ modifyTVar' (schedulerAsyncCompleted scheduler) (Map.insert taskId result)
|
||||
result <- (Right <$> ioAction) `catch` \(e :: SomeException) -> pure (Left (show e))
|
||||
atomically $ modifyTVar' (schedulerAsyncCompleted scheduler) (Map.insert taskId $
|
||||
case result of
|
||||
Right val -> val
|
||||
Left msg -> errResult msg)
|
||||
pure scheduler
|
||||
{ schedulerTasks = Map.insert taskId (AsyncWaiting machine) (schedulerTasks scheduler)
|
||||
}
|
||||
@@ -784,7 +1097,7 @@ schedulerStep scheduler = do
|
||||
taskId :< restQueue ->
|
||||
case Map.lookup taskId (schedulerTasks scheduler1) of
|
||||
Just (Runnable machine) -> do
|
||||
step <- stepMachine machine
|
||||
step <- stepMachine (schedulerSockets scheduler1) machine
|
||||
handleStep taskId step scheduler1 { schedulerRunnable = restQueue }
|
||||
|
||||
_ ->
|
||||
@@ -809,6 +1122,7 @@ runIOWith perms env initialState action =
|
||||
Left err -> pure (Left err)
|
||||
Right (_, action') -> do
|
||||
asyncVar <- newTVarIO Map.empty
|
||||
sockVar <- newTVarIO (SocketRegistry Map.empty 0)
|
||||
let initialMachine = Machine
|
||||
{ machineRuntime = Runtime
|
||||
{ rtPerms = perms
|
||||
@@ -818,7 +1132,7 @@ runIOWith perms env initialState action =
|
||||
, machineCurrent = action'
|
||||
, machineFrames = []
|
||||
}
|
||||
Right <$> runScheduler (initialScheduler asyncVar initialMachine)
|
||||
Right <$> runScheduler (initialScheduler asyncVar sockVar initialMachine)
|
||||
|
||||
runIOWithEnv :: IOPermissions -> T -> T -> IO (Either String T)
|
||||
runIOWithEnv perms env action = do
|
||||
|
||||
189
test/Spec.hs
189
test/Spec.hs
@@ -10,7 +10,8 @@ import Wire
|
||||
import ContentStore
|
||||
import IODriver (IOPermissions(..), checkIOSentinel, runIO, runIOWithEnv, runIOWith, unsafePerms, defaultPerms)
|
||||
|
||||
import Control.Exception (evaluate, try, SomeException)
|
||||
import Control.Exception (bracket, evaluate, try, SomeException)
|
||||
import qualified Network.Socket as NS
|
||||
import Control.Monad (forM_)
|
||||
import Control.Monad.IO.Class (liftIO)
|
||||
import System.IO.Temp (withSystemTempDirectory)
|
||||
@@ -1918,12 +1919,196 @@ ioDriverTests = testGroup "IO driver tests"
|
||||
build k = "bind (fork (pure \"x\")) (h : bind (await h) (_ : " ++ build (k - 1) ++ "))"
|
||||
(final, _) <- runIOSourceWith unsafePerms Leaf Leaf ("main = io (" ++ build n ++ ")")
|
||||
final @?= ofString "done"
|
||||
|
||||
, testGroup "Socket primitives"
|
||||
[ testCase "socket returns ok result with valid handle" $ do
|
||||
final <- runIOSource "main = io socket"
|
||||
final @?= ioOkResult (Fork (ofString "sock") (ofNumber 0))
|
||||
|
||||
, testCase "closeSocket on invalid handle returns error" $ do
|
||||
final <- runIOSource "main = io (closeSocket (pair \"sock\" 99999))"
|
||||
final @?= ioErrResult "invalid socket handle"
|
||||
|
||||
, testCase "bindSocket and listen succeed on loopback port 0" $ do
|
||||
final <- runIOSource "main = io (bind socket (result : matchResult (err rest : pure result) (sock rest : bind (bindSocket sock \"127.0.0.1\" 0) (bindResult : matchResult (err rest : pure bindResult) (_ rest : bind (listen sock 1) (listenResult : pure listenResult)) bindResult)) result))"
|
||||
final @?= ioOkResult Leaf
|
||||
|
||||
, testCase "connect to non-listening port returns error" $ do
|
||||
final <- runIOSource $
|
||||
unlines
|
||||
[ "main = io (bind socket (result :"
|
||||
, " matchResult"
|
||||
, " (err rest : pure \"socket-err\")"
|
||||
, " (sock rest : connect sock \"127.0.0.1\" 1)"
|
||||
, " result))"
|
||||
]
|
||||
case final of
|
||||
Fork Leaf (Fork _ Leaf) -> return ()
|
||||
other -> assertFailure $ "Expected error result, got: " ++ show other
|
||||
|
||||
, testCase "accept and recv receive bytes from forked client" $
|
||||
withFreePort $ \port -> do
|
||||
final <- runIOSource $
|
||||
unlines
|
||||
[ "preserveResult = (result okCase :"
|
||||
, " matchResult"
|
||||
, " (err rest : pure result)"
|
||||
, " okCase"
|
||||
, " result)"
|
||||
, ""
|
||||
, "client = port :"
|
||||
, " bind socket (result :"
|
||||
, " preserveResult result (sock rest :"
|
||||
, " bind (connect sock \"127.0.0.1\" port) (connectResult :"
|
||||
, " preserveResult connectResult (_ rest :"
|
||||
, " send sock [104 105]))))"
|
||||
, ""
|
||||
, "main = io ("
|
||||
, " bind socket (result :"
|
||||
, " preserveResult result (server rest :"
|
||||
, " bind (bindSocket server \"127.0.0.1\" " ++ show port ++ ") (bindResult :"
|
||||
, " preserveResult bindResult (_ rest :"
|
||||
, " bind (listen server 1) (listenResult :"
|
||||
, " preserveResult listenResult (_ rest :"
|
||||
, " bind (fork (client " ++ show port ++ ")) (_ :"
|
||||
, " bind (accept server) (acceptResult :"
|
||||
, " preserveResult acceptResult (accepted rest :"
|
||||
, " matchPair"
|
||||
, " (clientSock addr : recv clientSock 2)"
|
||||
, " accepted))))))))))"
|
||||
]
|
||||
final @?= ioOkResult (ofBytes (BS.pack [104, 105]))
|
||||
|
||||
, testCase "client recv receives server response via accepted socket" $
|
||||
withFreePort $ \port -> do
|
||||
final <- runIOSource $
|
||||
unlines
|
||||
[ "preserveResult = (result okCase :"
|
||||
, " matchResult"
|
||||
, " (err rest : pure result)"
|
||||
, " okCase"
|
||||
, " result)"
|
||||
, ""
|
||||
, "serverTask = server :"
|
||||
, " bind (accept server) (acceptResult :"
|
||||
, " preserveResult acceptResult (accepted rest :"
|
||||
, " matchPair"
|
||||
, " (clientSock addr :"
|
||||
, " bind (recv clientSock 4) (msgResult :"
|
||||
, " preserveResult msgResult (_ rest :"
|
||||
, " send clientSock [112 111 110 103])))"
|
||||
, " accepted))"
|
||||
, ""
|
||||
, "clientTask = port :"
|
||||
, " bind socket (result :"
|
||||
, " preserveResult result (sock rest :"
|
||||
, " bind (connect sock \"127.0.0.1\" port) (connectResult :"
|
||||
, " preserveResult connectResult (_ rest :"
|
||||
, " bind (send sock [112 105 110 103]) (_ :"
|
||||
, " recv sock 4)))))"
|
||||
, ""
|
||||
, "main = io ("
|
||||
, " bind socket (result :"
|
||||
, " preserveResult result (server rest :"
|
||||
, " bind (bindSocket server \"127.0.0.1\" " ++ show port ++ ") (bindResult :"
|
||||
, " preserveResult bindResult (_ rest :"
|
||||
, " bind (listen server 1) (listenResult :"
|
||||
, " preserveResult listenResult (_ rest :"
|
||||
, " bind (fork (serverTask server)) (_ :"
|
||||
, " clientTask " ++ show port ++ "))))))))"
|
||||
]
|
||||
final @?= ioOkResult (ofBytes (BS.pack [112, 111, 110, 103]))
|
||||
|
||||
, testCase "recv on closed peer returns connection closed" $
|
||||
withFreePort $ \port -> do
|
||||
final <- runIOSource $
|
||||
unlines
|
||||
[ "preserveResult = (result okCase :"
|
||||
, " matchResult"
|
||||
, " (err rest : pure result)"
|
||||
, " okCase"
|
||||
, " result)"
|
||||
, ""
|
||||
, "clientTask = port :"
|
||||
, " bind socket (result :"
|
||||
, " preserveResult result (sock rest :"
|
||||
, " bind (connect sock \"127.0.0.1\" port) (connectResult :"
|
||||
, " preserveResult connectResult (_ rest :"
|
||||
, " closeSocket sock))))"
|
||||
, ""
|
||||
, "main = io ("
|
||||
, " bind socket (result :"
|
||||
, " preserveResult result (server rest :"
|
||||
, " bind (bindSocket server \"127.0.0.1\" " ++ show port ++ ") (bindResult :"
|
||||
, " preserveResult bindResult (_ rest :"
|
||||
, " bind (listen server 1) (listenResult :"
|
||||
, " preserveResult listenResult (_ rest :"
|
||||
, " bind (fork (clientTask " ++ show port ++ ")) (_ :"
|
||||
, " bind (accept server) (acceptResult :"
|
||||
, " preserveResult acceptResult (accepted rest :"
|
||||
, " matchPair"
|
||||
, " (clientSock addr :"
|
||||
, " bind (yield) (_ :"
|
||||
, " recv clientSock 1))"
|
||||
, " accepted))))))))))"
|
||||
]
|
||||
final @?= ioErrResult "connection closed"
|
||||
|
||||
, testCase "accept invalid socket handle returns error" $ do
|
||||
final <- runIOSource "main = io (accept (pair \"sock\" 99999))"
|
||||
final @?= ioErrResult "invalid socket handle"
|
||||
|
||||
, testCase "recv invalid socket handle returns error" $ do
|
||||
final <- runIOSource "main = io (recv (pair \"sock\" 99999) 1)"
|
||||
final @?= ioErrResult "invalid socket handle"
|
||||
|
||||
, testCase "send invalid socket handle returns error" $ do
|
||||
final <- runIOSource "main = io (send (pair \"sock\" 99999) [(1)])"
|
||||
final @?= ioErrResult "invalid socket handle"
|
||||
|
||||
, testCase "getSocketName returns positive port after bind 0" $ do
|
||||
final <- runIOSource $
|
||||
unlines
|
||||
[ "preserveResult = (result okCase :"
|
||||
, " matchResult"
|
||||
, " (err rest : pure result)"
|
||||
, " okCase"
|
||||
, " result)"
|
||||
, ""
|
||||
, "main = io ("
|
||||
, " bind socket (result :"
|
||||
, " preserveResult result (server rest :"
|
||||
, " bind (bindSocket server \"127.0.0.1\" 0) (bindResult :"
|
||||
, " preserveResult bindResult (_ rest :"
|
||||
, " getSocketName server)))))"
|
||||
]
|
||||
case final of
|
||||
Fork (Stem Leaf) (Fork val Leaf) ->
|
||||
case toNumber val of
|
||||
Right port | port > 0 -> return ()
|
||||
Right 0 -> assertFailure "Expected positive port, got 0"
|
||||
Left _ -> assertFailure $ "Expected numeric port, got: " ++ show val
|
||||
other -> assertFailure $ "Expected ok result, got: " ++ show other
|
||||
]
|
||||
]
|
||||
|
||||
withFreePort :: (Int -> IO a) -> IO a
|
||||
withFreePort action =
|
||||
bracket
|
||||
(NS.socket NS.AF_INET NS.Stream NS.defaultProtocol)
|
||||
NS.close
|
||||
(\s -> do
|
||||
NS.setSocketOption s NS.ReuseAddr 1
|
||||
NS.bind s (NS.SockAddrInet 0 (NS.tupleToHostAddress (127, 0, 0, 1)))
|
||||
port <- NS.socketPort s
|
||||
action (fromIntegral port))
|
||||
|
||||
runIOSourceWith :: IOPermissions -> T -> T -> String -> IO (T, T)
|
||||
runIOSourceWith perms readerEnv initialState source = do
|
||||
ioEnv <- evaluateFile "./lib/io.tri"
|
||||
evalEnv <- evalTricuWithStore Nothing ioEnv (parseTricu source)
|
||||
sockEnv <- evaluateFile "./lib/socket.tri"
|
||||
let combinedEnv = Map.union sockEnv ioEnv
|
||||
evalEnv <- evalTricuWithStore Nothing combinedEnv (parseTricu source)
|
||||
let fullTree = mainResult evalEnv
|
||||
result <- runIOWith perms readerEnv initialState fullTree
|
||||
case result of
|
||||
|
||||
@@ -52,6 +52,7 @@ executable tricu
|
||||
, megaparsec
|
||||
, memory
|
||||
, mtl
|
||||
, network
|
||||
, servant
|
||||
, sqlite-simple
|
||||
, stm
|
||||
@@ -102,6 +103,7 @@ benchmark tricu-bench
|
||||
, megaparsec
|
||||
, memory
|
||||
, mtl
|
||||
, network
|
||||
, sqlite-simple
|
||||
, text
|
||||
, time
|
||||
@@ -148,6 +150,7 @@ test-suite tricu-tests
|
||||
, megaparsec
|
||||
, memory
|
||||
, mtl
|
||||
, network
|
||||
, servant
|
||||
, sqlite-simple
|
||||
, stm
|
||||
|
||||
Reference in New Issue
Block a user