Event loop!

This commit is contained in:
2026-05-19 17:00:36 -05:00
parent 2e13583de3
commit 020fa769a9
5 changed files with 622 additions and 11 deletions

View File

@@ -0,0 +1,46 @@
!import "../../lib/base.tri" !Local
!import "../../lib/io.tri" !Local
!import "../../lib/socket.tri" !Local
-- Preserve the host-driver Result shape on error, run okCase on success.
onOk = action okCase :
bind action (result :
matchResult
(err rest : pure result)
okCase
result)
-- Convenience: print a string and continue.
printLn = s : bind (putStr (append s "\n")) (_ : pure t)
-- Main accept+echo loop. Recursion via y.
echoLoop = y (self server :
bind (accept server) (acceptResult :
matchResult
(err rest :
bind (printLn (append "accept error: " err)) (_ :
self server))
(accepted rest :
matchPair
(clientSock addr :
bind (printLn (append "client from " addr)) (_ :
bind (recv clientSock 4096) (msgResult :
matchResult
(err rest :
bind (closeSocket clientSock) (_ :
self server))
(msg rest :
bind (send clientSock msg) (_ :
bind (closeSocket clientSock) (_ :
self server)))
msgResult)))
accepted)
acceptResult))
main = io (
onOk socket (server rest :
onOk (bindSocket server "127.0.0.1" 0) (_ rest :
onOk (listen server 5) (_ rest :
onOk (getSocketName server) (port rest :
bind (printLn (append "Echo server listening on port " (showNumber port))) (_ :
echoLoop server))))))

63
lib/socket.tri Normal file
View File

@@ -0,0 +1,63 @@
!import "base.tri" !Local
!import "io.tri" !Local
-- Socket primitives for the IO driver.
-- All actions return a Result tree (see lib/base.tri):
-- ok value -- pair true (pair value t)
-- err msg -- pair false (pair msg t)
socket = pair 70 t
closeSocket = sock : pair 71 sock
bindSocket = sock addr port : pair 72 (pair sock (pair addr port))
listen = sock backlog : pair 73 (pair sock backlog)
accept = sock : pair 74 sock
connect = sock addr port : pair 75 (pair sock (pair addr port))
recv = sock maxBytes : pair 76 (pair sock maxBytes)
send = sock bytes : pair 77 (pair sock bytes)
getSocketName = sock : pair 78 sock
-- ---------------------------------------------------------------------------
-- Convenience helpers
-- ---------------------------------------------------------------------------
onSocket = (action errCase okCase :
bind action (result :
matchResult errCase okCase result))
-- Create a listening socket bound to an address and port.
-- Returns ok listenSocket or err message.
listenSocket = addr port backlog :
bind (socket) (result :
matchResult
(err rest : pure (err "socket creation failed"))
(sock rest :
bind (bindSocket sock addr port) (bindResult :
matchResult
(err rest : pure (err "bind failed"))
(_ rest :
bind (listen sock backlog) (listenResult :
matchResult
(err rest : pure (err "listen failed"))
(_ rest : pure (ok sock))
listenResult))
bindResult))
result)
-- Accept a connection and return (clientSocket, peerAddr).
-- The returned peerAddr is a string like "127.0.0.1:8080".
onAccept = (sock errCase okCase :
bind (accept sock) (result :
matchResult errCase okCase result))
-- Receive all available bytes up to maxBytes.
onRecv = (sock maxBytes errCase okCase :
bind (recv sock maxBytes) (result :
matchResult errCase okCase result))
-- Send bytes and return number of bytes sent.
onSend = (sock bytes errCase okCase :
bind (send sock bytes) (result :
matchResult errCase okCase result))
-- Close a socket, ignoring errors.
closeSocket_ = sock : bind (closeSocket sock) (_ : pure t)

View File

@@ -12,7 +12,7 @@ import Research (T(..), apply, toString, toNumber, ofString, ofNumber, ofBytes,
import qualified Data.ByteString as BS import qualified Data.ByteString as BS
import System.IO (putStr, getLine) import System.IO (putStr, getLine)
import qualified System.IO as IO import qualified System.IO as IO
import Control.Exception (try, IOException, SomeException) import Control.Exception (try, catch, IOException, SomeException)
import System.IO.Error (isDoesNotExistError, isPermissionError, isAlreadyExistsError) import System.IO.Error (isDoesNotExistError, isPermissionError, isAlreadyExistsError)
import Data.List (isPrefixOf) import Data.List (isPrefixOf)
import System.FilePath (normalise, isRelative, (</>), addTrailingPathSeparator, splitDirectories) import System.FilePath (normalise, isRelative, (</>), addTrailingPathSeparator, splitDirectories)
@@ -27,6 +27,8 @@ import Control.Concurrent.STM (TVar, newTVarIO, atomically, readTVar, writeTVar,
import qualified Data.Set as Set import qualified Data.Set as Set
import Data.Set (Set) import Data.Set (Set)
import qualified Data.Foldable as Fold import qualified Data.Foldable as Fold
import qualified Network.Socket as NS
import qualified Network.Socket.ByteString as NSB
-- --------------------------------------------------------------------------- -- ---------------------------------------------------------------------------
-- Permissions -- Permissions
@@ -107,6 +109,9 @@ pureAction x = Fork (ofNumber 0) x
invalidAsyncHandleResult :: T invalidAsyncHandleResult :: T
invalidAsyncHandleResult = errResult "invalid task handle" invalidAsyncHandleResult = errResult "invalid task handle"
invalidSocketHandleResult :: T
invalidSocketHandleResult = errResult "invalid socket handle"
selfAwaitResult :: T selfAwaitResult :: T
selfAwaitResult = errResult "self await" selfAwaitResult = errResult "self await"
@@ -145,6 +150,45 @@ decodeTaskHandle tree =
_ -> _ ->
Left "invalid task handle" Left "invalid task handle"
-- ---------------------------------------------------------------------------
-- Socket identity and handles
-- ---------------------------------------------------------------------------
newtype SockId = SockId Integer
deriving (Eq, Ord, Show)
sockHandle :: SockId -> T
sockHandle (SockId n) =
Fork (ofString "sock") (ofNumber n)
decodeSockHandle :: T -> Either String SockId
decodeSockHandle tree =
case tree of
Fork tag nTree -> do
tagString <- toString tag
if tagString == "sock"
then SockId <$> toNumber nTree
else Left "invalid socket handle tag"
_ ->
Left "invalid socket handle"
getSocketPort :: NS.Socket -> IO (Maybe Integer)
getSocketPort sock = do
addr <- NS.getSocketName sock
case addr of
NS.SockAddrInet p _ -> return (Just (fromIntegral p))
NS.SockAddrInet6 p _ _ _ -> return (Just (fromIntegral p))
_ -> return Nothing
-- ---------------------------------------------------------------------------
-- Socket registry
-- ---------------------------------------------------------------------------
data SocketRegistry = SocketRegistry
{ sockMap :: Map SockId NS.Socket
, sockNextId :: Integer
}
-- --------------------------------------------------------------------------- -- ---------------------------------------------------------------------------
-- Free-monad action AST -- Free-monad action AST
-- --------------------------------------------------------------------------- -- ---------------------------------------------------------------------------
@@ -166,6 +210,15 @@ data Action
| AAwait T | AAwait T
| AYield | AYield
| ASleep T | ASleep T
| ASocket
| ACloseSocket T
| ABindSocket T T T
| AListen T T
| AAccept T
| AConnect T T T
| ARecv T T
| ASend T T
| AGetSocketName T
deriving (Show) deriving (Show)
-- --------------------------------------------------------------------------- -- ---------------------------------------------------------------------------
@@ -200,6 +253,19 @@ tagAwait = 61
tagYield = 62 tagYield = 62
tagSleep = 63 tagSleep = 63
tagSocket, tagCloseSocket, tagBindSocket, tagListen, tagAccept :: Integer
tagSocket = 70
tagCloseSocket = 71
tagBindSocket = 72
tagListen = 73
tagAccept = 74
tagConnect, tagRecv, tagSend, tagGetSocketName :: Integer
tagConnect = 75
tagRecv = 76
tagSend = 77
tagGetSocketName = 78
data Step data Step
= Halt Runtime T = Halt Runtime T
| Continue Machine | Continue Machine
@@ -279,6 +345,43 @@ decodeAction tree =
Right n | n == tagSleep -> Right n | n == tagSleep ->
Right (ASleep payload) Right (ASleep payload)
Right n | n == tagSocket ->
Right ASocket
Right n | n == tagCloseSocket ->
Right (ACloseSocket payload)
Right n | n == tagBindSocket ->
case payload of
Fork sock (Fork addr port) -> Right (ABindSocket sock addr port)
_ -> Left "Invalid BindSocket: expected pair sock (pair addr port)"
Right n | n == tagListen ->
case payload of
Fork sock backlog -> Right (AListen sock backlog)
_ -> Left "Invalid Listen: expected pair sock backlog"
Right n | n == tagAccept ->
Right (AAccept payload)
Right n | n == tagConnect ->
case payload of
Fork sock (Fork addr port) -> Right (AConnect sock addr port)
_ -> Left "Invalid Connect: expected pair sock (pair addr port)"
Right n | n == tagRecv ->
case payload of
Fork sock maxBytes -> Right (ARecv sock maxBytes)
_ -> Left "Invalid Recv: expected pair sock maxBytes"
Right n | n == tagSend ->
case payload of
Fork sock bytes -> Right (ASend sock bytes)
_ -> Left "Invalid Send: expected pair sock bytes"
Right n | n == tagGetSocketName ->
Right (AGetSocketName payload)
Right n -> Right n ->
Left $ "Unknown IO action tag: " ++ show n Left $ "Unknown IO action tag: " ++ show n
@@ -312,8 +415,8 @@ finishValue machine value =
, machineFrames = rest , machineFrames = rest
}) })
stepMachine :: Machine -> IO Step stepMachine :: TVar SocketRegistry -> Machine -> IO Step
stepMachine machine = stepMachine sockVar machine =
case decodeAction (machineCurrent machine) of case decodeAction (machineCurrent machine) of
Right action -> dispatch action Right action -> dispatch action
Left _ -> finishValue machine (errResult "invalid action") Left _ -> finishValue machine (errResult "invalid action")
@@ -419,6 +522,207 @@ stepMachine machine =
_ -> _ ->
finishValue machine invalidSleepResult finishValue machine invalidSleepResult
ASocket -> do
result <- try (NS.socket NS.AF_INET NS.Stream NS.defaultProtocol) :: IO (Either SomeException NS.Socket)
case result of
Left e ->
finishValue machine (errResult ("io error: " ++ show e))
Right sock -> do
NS.setSocketOption sock NS.ReuseAddr 1
sid <- atomically $ do
SocketRegistry m next <- readTVar sockVar
let sid = SockId next
writeTVar sockVar (SocketRegistry (Map.insert sid sock m) (next + 1))
return sid
finishValue machine (okResult (sockHandle sid))
ACloseSocket sockTree ->
case decodeSockHandle sockTree of
Left _ -> finishValue machine invalidSocketHandleResult
Right sid -> do
mSock <- atomically $ do
SocketRegistry m next <- readTVar sockVar
case Map.lookup sid m of
Nothing -> return Nothing
Just sock -> do
writeTVar sockVar (SocketRegistry (Map.delete sid m) next)
return (Just sock)
case mSock of
Nothing -> finishValue machine invalidSocketHandleResult
Just sock -> do
NS.close sock
finishValue machine (okResult Leaf)
ABindSocket sockTree addrTree portTree ->
case decodeSockHandle sockTree of
Left _ -> finishValue machine invalidSocketHandleResult
Right sid ->
case decodeString addrTree "BindSocket" of
Left _ -> finishValue machine (errResult "invalid address")
Right addrStr ->
case toNumber portTree of
Left _ -> finishValue machine (errResult "invalid port")
Right port -> do
mSock <- atomically $ do
SocketRegistry m _ <- readTVar sockVar
return (Map.lookup sid m)
case mSock of
Nothing -> finishValue machine invalidSocketHandleResult
Just sock -> do
result <- try (do
addrInfo <- NS.getAddrInfo (Just $ NS.defaultHints { NS.addrSocketType = NS.Stream })
(Just addrStr)
(Just (show port))
let serverAddr = head addrInfo
NS.bind sock (NS.addrAddress serverAddr)
) :: IO (Either SomeException ())
case result of
Left e ->
finishValue machine (errResult ("io error: " ++ show e))
Right () ->
finishValue machine (okResult Leaf)
AListen sockTree backlogTree ->
case decodeSockHandle sockTree of
Left _ -> finishValue machine invalidSocketHandleResult
Right sid ->
case toNumber backlogTree of
Left _ -> finishValue machine (errResult "invalid backlog")
Right backlog -> do
mSock <- atomically $ do
SocketRegistry m _ <- readTVar sockVar
return (Map.lookup sid m)
case mSock of
Nothing -> finishValue machine invalidSocketHandleResult
Just sock -> do
result <- try (NS.listen sock (fromIntegral backlog)) :: IO (Either SomeException ())
case result of
Left e ->
finishValue machine (errResult ("io error: " ++ show e))
Right () ->
finishValue machine (okResult Leaf)
AAccept listenTree ->
case decodeSockHandle listenTree of
Left _ -> finishValue machine invalidSocketHandleResult
Right listenSid ->
pure (AsyncAction (do
mListenSock <- atomically $ do
SocketRegistry m _ <- readTVar sockVar
return (Map.lookup listenSid m)
case mListenSock of
Nothing -> return (errResult "invalid socket handle")
Just listenSock -> do
result <- try (NS.accept listenSock) :: IO (Either SomeException (NS.Socket, NS.SockAddr))
case result of
Left e ->
return (errResult ("io error: " ++ show e))
Right (clientSock, addr) -> do
clientSid <- atomically $ do
SocketRegistry m next <- readTVar sockVar
let sid = SockId next
writeTVar sockVar (SocketRegistry (Map.insert sid clientSock m) (next + 1))
return sid
let addrStr = case addr of
NS.SockAddrInet p h ->
let (a,b,c,d) = NS.hostAddressToTuple h
in show a ++ "." ++ show b ++ "." ++ show c ++ "." ++ show d ++ ":" ++ show p
_ -> show addr
return (okResult (Fork (sockHandle clientSid) (ofString addrStr)))
) machine)
AConnect sockTree addrTree portTree ->
case decodeSockHandle sockTree of
Left _ -> finishValue machine invalidSocketHandleResult
Right sid ->
case decodeString addrTree "Connect" of
Left _ -> finishValue machine (errResult "invalid address")
Right addrStr ->
case toNumber portTree of
Left _ -> finishValue machine (errResult "invalid port")
Right port -> do
mSock <- atomically $ do
SocketRegistry m _ <- readTVar sockVar
return (Map.lookup sid m)
case mSock of
Nothing -> finishValue machine invalidSocketHandleResult
Just sock ->
pure (AsyncAction (do
result <- try (do
addrInfo <- NS.getAddrInfo (Just $ NS.defaultHints { NS.addrSocketType = NS.Stream })
(Just addrStr)
(Just (show port))
let serverAddr = head addrInfo
NS.connect sock (NS.addrAddress serverAddr)
) :: IO (Either SomeException ())
case result of
Left e ->
return (errResult ("io error: " ++ show e))
Right () ->
return (okResult Leaf)
) machine)
ARecv sockTree maxBytesTree ->
case decodeSockHandle sockTree of
Left _ -> finishValue machine invalidSocketHandleResult
Right sid ->
case toNumber maxBytesTree of
Left _ -> finishValue machine (errResult "invalid maxBytes")
Right maxBytes -> do
mSock <- atomically $ do
SocketRegistry m _ <- readTVar sockVar
return (Map.lookup sid m)
case mSock of
Nothing -> finishValue machine invalidSocketHandleResult
Just sock ->
pure (AsyncAction (do
result <- try (NSB.recv sock (fromIntegral maxBytes)) :: IO (Either SomeException BS.ByteString)
case result of
Left e ->
return (errResult ("io error: " ++ show e))
Right bs ->
if BS.null bs
then return (errResult "connection closed")
else return (okResult (ofBytes bs))
) machine)
AGetSocketName sockTree ->
case decodeSockHandle sockTree of
Left _ -> finishValue machine invalidSocketHandleResult
Right sid -> do
mSock <- atomically $ do
SocketRegistry m _ <- readTVar sockVar
return (Map.lookup sid m)
case mSock of
Nothing -> finishValue machine invalidSocketHandleResult
Just sock -> do
mPort <- getSocketPort sock
case mPort of
Just port -> finishValue machine (okResult (ofNumber port))
Nothing -> finishValue machine (errResult "io error: could not get socket name")
ASend sockTree bytesTree ->
case decodeSockHandle sockTree of
Left _ -> finishValue machine invalidSocketHandleResult
Right sid ->
case decodeBytes bytesTree "Send" of
Left _ -> finishValue machine (errResult "invalid bytes")
Right bs -> do
mSock <- atomically $ do
SocketRegistry m _ <- readTVar sockVar
return (Map.lookup sid m)
case mSock of
Nothing -> finishValue machine invalidSocketHandleResult
Just sock ->
pure (AsyncAction (do
result <- try (NSB.send sock bs) :: IO (Either SomeException Int)
case result of
Left e ->
return (errResult ("io error: " ++ show e))
Right sent ->
return (okResult (ofNumber (fromIntegral sent)))
) machine)
-- Permission and IO helpers -- Permission and IO helpers
checkReadPerm p = checkReadPerm p =
if allowReadAll (rtPerms (machineRuntime machine)) if allowReadAll (rtPerms (machineRuntime machine))
@@ -543,6 +847,8 @@ data Scheduler = Scheduler
, schedulerSleepQueue :: Map UTCTime (Set TaskId) , schedulerSleepQueue :: Map UTCTime (Set TaskId)
, schedulerAsyncCompleted :: TVar (Map TaskId T) , schedulerAsyncCompleted :: TVar (Map TaskId T)
, schedulerCompleted :: Map TaskId (T, T) , schedulerCompleted :: Map TaskId (T, T)
, schedulerSockets :: TVar SocketRegistry
, schedulerNextSockId :: Integer
} }
instance Show Scheduler where instance Show Scheduler where
@@ -553,10 +859,12 @@ instance Show Scheduler where
++ ", schedulerSleepQueue = " ++ show (schedulerSleepQueue s) ++ ", schedulerSleepQueue = " ++ show (schedulerSleepQueue s)
++ ", schedulerAsyncCompleted = <tvar>" ++ ", schedulerAsyncCompleted = <tvar>"
++ ", schedulerCompleted = " ++ show (schedulerCompleted s) ++ ", schedulerCompleted = " ++ show (schedulerCompleted s)
++ ", schedulerSockets = <tvar>"
++ ", schedulerNextSockId = " ++ show (schedulerNextSockId s)
++ " }" ++ " }"
initialScheduler :: TVar (Map TaskId T) -> Machine -> Scheduler initialScheduler :: TVar (Map TaskId T) -> TVar SocketRegistry -> Machine -> Scheduler
initialScheduler asyncVar mainMachine = initialScheduler asyncVar sockVar mainMachine =
Scheduler Scheduler
{ schedulerNextTaskId = 1 { schedulerNextTaskId = 1
, schedulerRunnable = Seq.singleton (TaskId 0) , schedulerRunnable = Seq.singleton (TaskId 0)
@@ -565,6 +873,8 @@ initialScheduler asyncVar mainMachine =
, schedulerSleepQueue = Map.empty , schedulerSleepQueue = Map.empty
, schedulerAsyncCompleted = asyncVar , schedulerAsyncCompleted = asyncVar
, schedulerCompleted = Map.empty , schedulerCompleted = Map.empty
, schedulerSockets = sockVar
, schedulerNextSockId = 0
} }
runtimeOfStatus :: TaskStatus -> Maybe Runtime runtimeOfStatus :: TaskStatus -> Maybe Runtime
@@ -725,8 +1035,11 @@ handleStep taskId (SleepRequested ms machine) scheduler = do
handleStep taskId (AsyncAction ioAction machine) scheduler = do handleStep taskId (AsyncAction ioAction machine) scheduler = do
_ <- forkIO $ do _ <- forkIO $ do
result <- ioAction result <- (Right <$> ioAction) `catch` \(e :: SomeException) -> pure (Left (show e))
atomically $ modifyTVar' (schedulerAsyncCompleted scheduler) (Map.insert taskId result) atomically $ modifyTVar' (schedulerAsyncCompleted scheduler) (Map.insert taskId $
case result of
Right val -> val
Left msg -> errResult msg)
pure scheduler pure scheduler
{ schedulerTasks = Map.insert taskId (AsyncWaiting machine) (schedulerTasks scheduler) { schedulerTasks = Map.insert taskId (AsyncWaiting machine) (schedulerTasks scheduler)
} }
@@ -784,7 +1097,7 @@ schedulerStep scheduler = do
taskId :< restQueue -> taskId :< restQueue ->
case Map.lookup taskId (schedulerTasks scheduler1) of case Map.lookup taskId (schedulerTasks scheduler1) of
Just (Runnable machine) -> do Just (Runnable machine) -> do
step <- stepMachine machine step <- stepMachine (schedulerSockets scheduler1) machine
handleStep taskId step scheduler1 { schedulerRunnable = restQueue } handleStep taskId step scheduler1 { schedulerRunnable = restQueue }
_ -> _ ->
@@ -809,6 +1122,7 @@ runIOWith perms env initialState action =
Left err -> pure (Left err) Left err -> pure (Left err)
Right (_, action') -> do Right (_, action') -> do
asyncVar <- newTVarIO Map.empty asyncVar <- newTVarIO Map.empty
sockVar <- newTVarIO (SocketRegistry Map.empty 0)
let initialMachine = Machine let initialMachine = Machine
{ machineRuntime = Runtime { machineRuntime = Runtime
{ rtPerms = perms { rtPerms = perms
@@ -818,7 +1132,7 @@ runIOWith perms env initialState action =
, machineCurrent = action' , machineCurrent = action'
, machineFrames = [] , machineFrames = []
} }
Right <$> runScheduler (initialScheduler asyncVar initialMachine) Right <$> runScheduler (initialScheduler asyncVar sockVar initialMachine)
runIOWithEnv :: IOPermissions -> T -> T -> IO (Either String T) runIOWithEnv :: IOPermissions -> T -> T -> IO (Either String T)
runIOWithEnv perms env action = do runIOWithEnv perms env action = do

View File

@@ -10,7 +10,8 @@ import Wire
import ContentStore import ContentStore
import IODriver (IOPermissions(..), checkIOSentinel, runIO, runIOWithEnv, runIOWith, unsafePerms, defaultPerms) import IODriver (IOPermissions(..), checkIOSentinel, runIO, runIOWithEnv, runIOWith, unsafePerms, defaultPerms)
import Control.Exception (evaluate, try, SomeException) import Control.Exception (bracket, evaluate, try, SomeException)
import qualified Network.Socket as NS
import Control.Monad (forM_) import Control.Monad (forM_)
import Control.Monad.IO.Class (liftIO) import Control.Monad.IO.Class (liftIO)
import System.IO.Temp (withSystemTempDirectory) import System.IO.Temp (withSystemTempDirectory)
@@ -1918,12 +1919,196 @@ ioDriverTests = testGroup "IO driver tests"
build k = "bind (fork (pure \"x\")) (h : bind (await h) (_ : " ++ build (k - 1) ++ "))" build k = "bind (fork (pure \"x\")) (h : bind (await h) (_ : " ++ build (k - 1) ++ "))"
(final, _) <- runIOSourceWith unsafePerms Leaf Leaf ("main = io (" ++ build n ++ ")") (final, _) <- runIOSourceWith unsafePerms Leaf Leaf ("main = io (" ++ build n ++ ")")
final @?= ofString "done" final @?= ofString "done"
, testGroup "Socket primitives"
[ testCase "socket returns ok result with valid handle" $ do
final <- runIOSource "main = io socket"
final @?= ioOkResult (Fork (ofString "sock") (ofNumber 0))
, testCase "closeSocket on invalid handle returns error" $ do
final <- runIOSource "main = io (closeSocket (pair \"sock\" 99999))"
final @?= ioErrResult "invalid socket handle"
, testCase "bindSocket and listen succeed on loopback port 0" $ do
final <- runIOSource "main = io (bind socket (result : matchResult (err rest : pure result) (sock rest : bind (bindSocket sock \"127.0.0.1\" 0) (bindResult : matchResult (err rest : pure bindResult) (_ rest : bind (listen sock 1) (listenResult : pure listenResult)) bindResult)) result))"
final @?= ioOkResult Leaf
, testCase "connect to non-listening port returns error" $ do
final <- runIOSource $
unlines
[ "main = io (bind socket (result :"
, " matchResult"
, " (err rest : pure \"socket-err\")"
, " (sock rest : connect sock \"127.0.0.1\" 1)"
, " result))"
]
case final of
Fork Leaf (Fork _ Leaf) -> return ()
other -> assertFailure $ "Expected error result, got: " ++ show other
, testCase "accept and recv receive bytes from forked client" $
withFreePort $ \port -> do
final <- runIOSource $
unlines
[ "preserveResult = (result okCase :"
, " matchResult"
, " (err rest : pure result)"
, " okCase"
, " result)"
, ""
, "client = port :"
, " bind socket (result :"
, " preserveResult result (sock rest :"
, " bind (connect sock \"127.0.0.1\" port) (connectResult :"
, " preserveResult connectResult (_ rest :"
, " send sock [104 105]))))"
, ""
, "main = io ("
, " bind socket (result :"
, " preserveResult result (server rest :"
, " bind (bindSocket server \"127.0.0.1\" " ++ show port ++ ") (bindResult :"
, " preserveResult bindResult (_ rest :"
, " bind (listen server 1) (listenResult :"
, " preserveResult listenResult (_ rest :"
, " bind (fork (client " ++ show port ++ ")) (_ :"
, " bind (accept server) (acceptResult :"
, " preserveResult acceptResult (accepted rest :"
, " matchPair"
, " (clientSock addr : recv clientSock 2)"
, " accepted))))))))))"
]
final @?= ioOkResult (ofBytes (BS.pack [104, 105]))
, testCase "client recv receives server response via accepted socket" $
withFreePort $ \port -> do
final <- runIOSource $
unlines
[ "preserveResult = (result okCase :"
, " matchResult"
, " (err rest : pure result)"
, " okCase"
, " result)"
, ""
, "serverTask = server :"
, " bind (accept server) (acceptResult :"
, " preserveResult acceptResult (accepted rest :"
, " matchPair"
, " (clientSock addr :"
, " bind (recv clientSock 4) (msgResult :"
, " preserveResult msgResult (_ rest :"
, " send clientSock [112 111 110 103])))"
, " accepted))"
, ""
, "clientTask = port :"
, " bind socket (result :"
, " preserveResult result (sock rest :"
, " bind (connect sock \"127.0.0.1\" port) (connectResult :"
, " preserveResult connectResult (_ rest :"
, " bind (send sock [112 105 110 103]) (_ :"
, " recv sock 4)))))"
, ""
, "main = io ("
, " bind socket (result :"
, " preserveResult result (server rest :"
, " bind (bindSocket server \"127.0.0.1\" " ++ show port ++ ") (bindResult :"
, " preserveResult bindResult (_ rest :"
, " bind (listen server 1) (listenResult :"
, " preserveResult listenResult (_ rest :"
, " bind (fork (serverTask server)) (_ :"
, " clientTask " ++ show port ++ "))))))))"
]
final @?= ioOkResult (ofBytes (BS.pack [112, 111, 110, 103]))
, testCase "recv on closed peer returns connection closed" $
withFreePort $ \port -> do
final <- runIOSource $
unlines
[ "preserveResult = (result okCase :"
, " matchResult"
, " (err rest : pure result)"
, " okCase"
, " result)"
, ""
, "clientTask = port :"
, " bind socket (result :"
, " preserveResult result (sock rest :"
, " bind (connect sock \"127.0.0.1\" port) (connectResult :"
, " preserveResult connectResult (_ rest :"
, " closeSocket sock))))"
, ""
, "main = io ("
, " bind socket (result :"
, " preserveResult result (server rest :"
, " bind (bindSocket server \"127.0.0.1\" " ++ show port ++ ") (bindResult :"
, " preserveResult bindResult (_ rest :"
, " bind (listen server 1) (listenResult :"
, " preserveResult listenResult (_ rest :"
, " bind (fork (clientTask " ++ show port ++ ")) (_ :"
, " bind (accept server) (acceptResult :"
, " preserveResult acceptResult (accepted rest :"
, " matchPair"
, " (clientSock addr :"
, " bind (yield) (_ :"
, " recv clientSock 1))"
, " accepted))))))))))"
]
final @?= ioErrResult "connection closed"
, testCase "accept invalid socket handle returns error" $ do
final <- runIOSource "main = io (accept (pair \"sock\" 99999))"
final @?= ioErrResult "invalid socket handle"
, testCase "recv invalid socket handle returns error" $ do
final <- runIOSource "main = io (recv (pair \"sock\" 99999) 1)"
final @?= ioErrResult "invalid socket handle"
, testCase "send invalid socket handle returns error" $ do
final <- runIOSource "main = io (send (pair \"sock\" 99999) [(1)])"
final @?= ioErrResult "invalid socket handle"
, testCase "getSocketName returns positive port after bind 0" $ do
final <- runIOSource $
unlines
[ "preserveResult = (result okCase :"
, " matchResult"
, " (err rest : pure result)"
, " okCase"
, " result)"
, ""
, "main = io ("
, " bind socket (result :"
, " preserveResult result (server rest :"
, " bind (bindSocket server \"127.0.0.1\" 0) (bindResult :"
, " preserveResult bindResult (_ rest :"
, " getSocketName server)))))"
]
case final of
Fork (Stem Leaf) (Fork val Leaf) ->
case toNumber val of
Right port | port > 0 -> return ()
Right 0 -> assertFailure "Expected positive port, got 0"
Left _ -> assertFailure $ "Expected numeric port, got: " ++ show val
other -> assertFailure $ "Expected ok result, got: " ++ show other
]
] ]
withFreePort :: (Int -> IO a) -> IO a
withFreePort action =
bracket
(NS.socket NS.AF_INET NS.Stream NS.defaultProtocol)
NS.close
(\s -> do
NS.setSocketOption s NS.ReuseAddr 1
NS.bind s (NS.SockAddrInet 0 (NS.tupleToHostAddress (127, 0, 0, 1)))
port <- NS.socketPort s
action (fromIntegral port))
runIOSourceWith :: IOPermissions -> T -> T -> String -> IO (T, T) runIOSourceWith :: IOPermissions -> T -> T -> String -> IO (T, T)
runIOSourceWith perms readerEnv initialState source = do runIOSourceWith perms readerEnv initialState source = do
ioEnv <- evaluateFile "./lib/io.tri" ioEnv <- evaluateFile "./lib/io.tri"
evalEnv <- evalTricuWithStore Nothing ioEnv (parseTricu source) sockEnv <- evaluateFile "./lib/socket.tri"
let combinedEnv = Map.union sockEnv ioEnv
evalEnv <- evalTricuWithStore Nothing combinedEnv (parseTricu source)
let fullTree = mainResult evalEnv let fullTree = mainResult evalEnv
result <- runIOWith perms readerEnv initialState fullTree result <- runIOWith perms readerEnv initialState fullTree
case result of case result of

View File

@@ -52,6 +52,7 @@ executable tricu
, megaparsec , megaparsec
, memory , memory
, mtl , mtl
, network
, servant , servant
, sqlite-simple , sqlite-simple
, stm , stm
@@ -102,6 +103,7 @@ benchmark tricu-bench
, megaparsec , megaparsec
, memory , memory
, mtl , mtl
, network
, sqlite-simple , sqlite-simple
, text , text
, time , time
@@ -148,6 +150,7 @@ test-suite tricu-tests
, megaparsec , megaparsec
, memory , memory
, mtl , mtl
, network
, servant , servant
, sqlite-simple , sqlite-simple
, stm , stm