From 020fa769a94c8d43ad1e197a6c20c3a3b6e5d45e Mon Sep 17 00:00:00 2001 From: James Eversole Date: Tue, 19 May 2026 17:00:36 -0500 Subject: [PATCH] Event loop! --- demos/interactionTrees/echo-server.tri | 46 ++++ lib/socket.tri | 63 +++++ src/IODriver.hs | 332 ++++++++++++++++++++++++- test/Spec.hs | 189 +++++++++++++- tricu.cabal | 3 + 5 files changed, 622 insertions(+), 11 deletions(-) create mode 100644 demos/interactionTrees/echo-server.tri create mode 100644 lib/socket.tri diff --git a/demos/interactionTrees/echo-server.tri b/demos/interactionTrees/echo-server.tri new file mode 100644 index 0000000..67f3050 --- /dev/null +++ b/demos/interactionTrees/echo-server.tri @@ -0,0 +1,46 @@ +!import "../../lib/base.tri" !Local +!import "../../lib/io.tri" !Local +!import "../../lib/socket.tri" !Local + +-- Preserve the host-driver Result shape on error, run okCase on success. +onOk = action okCase : + bind action (result : + matchResult + (err rest : pure result) + okCase + result) + +-- Convenience: print a string and continue. +printLn = s : bind (putStr (append s "\n")) (_ : pure t) + +-- Main accept+echo loop. Recursion via y. +echoLoop = y (self server : + bind (accept server) (acceptResult : + matchResult + (err rest : + bind (printLn (append "accept error: " err)) (_ : + self server)) + (accepted rest : + matchPair + (clientSock addr : + bind (printLn (append "client from " addr)) (_ : + bind (recv clientSock 4096) (msgResult : + matchResult + (err rest : + bind (closeSocket clientSock) (_ : + self server)) + (msg rest : + bind (send clientSock msg) (_ : + bind (closeSocket clientSock) (_ : + self server))) + msgResult))) + accepted) + acceptResult)) + +main = io ( + onOk socket (server rest : + onOk (bindSocket server "127.0.0.1" 0) (_ rest : + onOk (listen server 5) (_ rest : + onOk (getSocketName server) (port rest : + bind (printLn (append "Echo server listening on port " (showNumber port))) (_ : + echoLoop server)))))) diff --git a/lib/socket.tri b/lib/socket.tri new file mode 100644 index 0000000..defe11f --- /dev/null +++ b/lib/socket.tri @@ -0,0 +1,63 @@ +!import "base.tri" !Local +!import "io.tri" !Local + +-- Socket primitives for the IO driver. +-- All actions return a Result tree (see lib/base.tri): +-- ok value -- pair true (pair value t) +-- err msg -- pair false (pair msg t) + +socket = pair 70 t +closeSocket = sock : pair 71 sock +bindSocket = sock addr port : pair 72 (pair sock (pair addr port)) +listen = sock backlog : pair 73 (pair sock backlog) +accept = sock : pair 74 sock +connect = sock addr port : pair 75 (pair sock (pair addr port)) +recv = sock maxBytes : pair 76 (pair sock maxBytes) +send = sock bytes : pair 77 (pair sock bytes) +getSocketName = sock : pair 78 sock + +-- --------------------------------------------------------------------------- +-- Convenience helpers +-- --------------------------------------------------------------------------- + +onSocket = (action errCase okCase : + bind action (result : + matchResult errCase okCase result)) + +-- Create a listening socket bound to an address and port. +-- Returns ok listenSocket or err message. +listenSocket = addr port backlog : + bind (socket) (result : + matchResult + (err rest : pure (err "socket creation failed")) + (sock rest : + bind (bindSocket sock addr port) (bindResult : + matchResult + (err rest : pure (err "bind failed")) + (_ rest : + bind (listen sock backlog) (listenResult : + matchResult + (err rest : pure (err "listen failed")) + (_ rest : pure (ok sock)) + listenResult)) + bindResult)) + result) + +-- Accept a connection and return (clientSocket, peerAddr). +-- The returned peerAddr is a string like "127.0.0.1:8080". +onAccept = (sock errCase okCase : + bind (accept sock) (result : + matchResult errCase okCase result)) + +-- Receive all available bytes up to maxBytes. +onRecv = (sock maxBytes errCase okCase : + bind (recv sock maxBytes) (result : + matchResult errCase okCase result)) + +-- Send bytes and return number of bytes sent. +onSend = (sock bytes errCase okCase : + bind (send sock bytes) (result : + matchResult errCase okCase result)) + +-- Close a socket, ignoring errors. +closeSocket_ = sock : bind (closeSocket sock) (_ : pure t) diff --git a/src/IODriver.hs b/src/IODriver.hs index 9bb1ec9..e735a50 100644 --- a/src/IODriver.hs +++ b/src/IODriver.hs @@ -12,7 +12,7 @@ import Research (T(..), apply, toString, toNumber, ofString, ofNumber, ofBytes, import qualified Data.ByteString as BS import System.IO (putStr, getLine) import qualified System.IO as IO -import Control.Exception (try, IOException, SomeException) +import Control.Exception (try, catch, IOException, SomeException) import System.IO.Error (isDoesNotExistError, isPermissionError, isAlreadyExistsError) import Data.List (isPrefixOf) import System.FilePath (normalise, isRelative, (), addTrailingPathSeparator, splitDirectories) @@ -27,6 +27,8 @@ import Control.Concurrent.STM (TVar, newTVarIO, atomically, readTVar, writeTVar, import qualified Data.Set as Set import Data.Set (Set) import qualified Data.Foldable as Fold +import qualified Network.Socket as NS +import qualified Network.Socket.ByteString as NSB -- --------------------------------------------------------------------------- -- Permissions @@ -107,6 +109,9 @@ pureAction x = Fork (ofNumber 0) x invalidAsyncHandleResult :: T invalidAsyncHandleResult = errResult "invalid task handle" +invalidSocketHandleResult :: T +invalidSocketHandleResult = errResult "invalid socket handle" + selfAwaitResult :: T selfAwaitResult = errResult "self await" @@ -145,6 +150,45 @@ decodeTaskHandle tree = _ -> Left "invalid task handle" +-- --------------------------------------------------------------------------- +-- Socket identity and handles +-- --------------------------------------------------------------------------- + +newtype SockId = SockId Integer + deriving (Eq, Ord, Show) + +sockHandle :: SockId -> T +sockHandle (SockId n) = + Fork (ofString "sock") (ofNumber n) + +decodeSockHandle :: T -> Either String SockId +decodeSockHandle tree = + case tree of + Fork tag nTree -> do + tagString <- toString tag + if tagString == "sock" + then SockId <$> toNumber nTree + else Left "invalid socket handle tag" + _ -> + Left "invalid socket handle" + +getSocketPort :: NS.Socket -> IO (Maybe Integer) +getSocketPort sock = do + addr <- NS.getSocketName sock + case addr of + NS.SockAddrInet p _ -> return (Just (fromIntegral p)) + NS.SockAddrInet6 p _ _ _ -> return (Just (fromIntegral p)) + _ -> return Nothing + +-- --------------------------------------------------------------------------- +-- Socket registry +-- --------------------------------------------------------------------------- + +data SocketRegistry = SocketRegistry + { sockMap :: Map SockId NS.Socket + , sockNextId :: Integer + } + -- --------------------------------------------------------------------------- -- Free-monad action AST -- --------------------------------------------------------------------------- @@ -166,6 +210,15 @@ data Action | AAwait T | AYield | ASleep T + | ASocket + | ACloseSocket T + | ABindSocket T T T + | AListen T T + | AAccept T + | AConnect T T T + | ARecv T T + | ASend T T + | AGetSocketName T deriving (Show) -- --------------------------------------------------------------------------- @@ -200,6 +253,19 @@ tagAwait = 61 tagYield = 62 tagSleep = 63 +tagSocket, tagCloseSocket, tagBindSocket, tagListen, tagAccept :: Integer +tagSocket = 70 +tagCloseSocket = 71 +tagBindSocket = 72 +tagListen = 73 +tagAccept = 74 + +tagConnect, tagRecv, tagSend, tagGetSocketName :: Integer +tagConnect = 75 +tagRecv = 76 +tagSend = 77 +tagGetSocketName = 78 + data Step = Halt Runtime T | Continue Machine @@ -279,6 +345,43 @@ decodeAction tree = Right n | n == tagSleep -> Right (ASleep payload) + Right n | n == tagSocket -> + Right ASocket + + Right n | n == tagCloseSocket -> + Right (ACloseSocket payload) + + Right n | n == tagBindSocket -> + case payload of + Fork sock (Fork addr port) -> Right (ABindSocket sock addr port) + _ -> Left "Invalid BindSocket: expected pair sock (pair addr port)" + + Right n | n == tagListen -> + case payload of + Fork sock backlog -> Right (AListen sock backlog) + _ -> Left "Invalid Listen: expected pair sock backlog" + + Right n | n == tagAccept -> + Right (AAccept payload) + + Right n | n == tagConnect -> + case payload of + Fork sock (Fork addr port) -> Right (AConnect sock addr port) + _ -> Left "Invalid Connect: expected pair sock (pair addr port)" + + Right n | n == tagRecv -> + case payload of + Fork sock maxBytes -> Right (ARecv sock maxBytes) + _ -> Left "Invalid Recv: expected pair sock maxBytes" + + Right n | n == tagSend -> + case payload of + Fork sock bytes -> Right (ASend sock bytes) + _ -> Left "Invalid Send: expected pair sock bytes" + + Right n | n == tagGetSocketName -> + Right (AGetSocketName payload) + Right n -> Left $ "Unknown IO action tag: " ++ show n @@ -312,8 +415,8 @@ finishValue machine value = , machineFrames = rest }) -stepMachine :: Machine -> IO Step -stepMachine machine = +stepMachine :: TVar SocketRegistry -> Machine -> IO Step +stepMachine sockVar machine = case decodeAction (machineCurrent machine) of Right action -> dispatch action Left _ -> finishValue machine (errResult "invalid action") @@ -419,6 +522,207 @@ stepMachine machine = _ -> finishValue machine invalidSleepResult + ASocket -> do + result <- try (NS.socket NS.AF_INET NS.Stream NS.defaultProtocol) :: IO (Either SomeException NS.Socket) + case result of + Left e -> + finishValue machine (errResult ("io error: " ++ show e)) + Right sock -> do + NS.setSocketOption sock NS.ReuseAddr 1 + sid <- atomically $ do + SocketRegistry m next <- readTVar sockVar + let sid = SockId next + writeTVar sockVar (SocketRegistry (Map.insert sid sock m) (next + 1)) + return sid + finishValue machine (okResult (sockHandle sid)) + + ACloseSocket sockTree -> + case decodeSockHandle sockTree of + Left _ -> finishValue machine invalidSocketHandleResult + Right sid -> do + mSock <- atomically $ do + SocketRegistry m next <- readTVar sockVar + case Map.lookup sid m of + Nothing -> return Nothing + Just sock -> do + writeTVar sockVar (SocketRegistry (Map.delete sid m) next) + return (Just sock) + case mSock of + Nothing -> finishValue machine invalidSocketHandleResult + Just sock -> do + NS.close sock + finishValue machine (okResult Leaf) + + ABindSocket sockTree addrTree portTree -> + case decodeSockHandle sockTree of + Left _ -> finishValue machine invalidSocketHandleResult + Right sid -> + case decodeString addrTree "BindSocket" of + Left _ -> finishValue machine (errResult "invalid address") + Right addrStr -> + case toNumber portTree of + Left _ -> finishValue machine (errResult "invalid port") + Right port -> do + mSock <- atomically $ do + SocketRegistry m _ <- readTVar sockVar + return (Map.lookup sid m) + case mSock of + Nothing -> finishValue machine invalidSocketHandleResult + Just sock -> do + result <- try (do + addrInfo <- NS.getAddrInfo (Just $ NS.defaultHints { NS.addrSocketType = NS.Stream }) + (Just addrStr) + (Just (show port)) + let serverAddr = head addrInfo + NS.bind sock (NS.addrAddress serverAddr) + ) :: IO (Either SomeException ()) + case result of + Left e -> + finishValue machine (errResult ("io error: " ++ show e)) + Right () -> + finishValue machine (okResult Leaf) + + AListen sockTree backlogTree -> + case decodeSockHandle sockTree of + Left _ -> finishValue machine invalidSocketHandleResult + Right sid -> + case toNumber backlogTree of + Left _ -> finishValue machine (errResult "invalid backlog") + Right backlog -> do + mSock <- atomically $ do + SocketRegistry m _ <- readTVar sockVar + return (Map.lookup sid m) + case mSock of + Nothing -> finishValue machine invalidSocketHandleResult + Just sock -> do + result <- try (NS.listen sock (fromIntegral backlog)) :: IO (Either SomeException ()) + case result of + Left e -> + finishValue machine (errResult ("io error: " ++ show e)) + Right () -> + finishValue machine (okResult Leaf) + + AAccept listenTree -> + case decodeSockHandle listenTree of + Left _ -> finishValue machine invalidSocketHandleResult + Right listenSid -> + pure (AsyncAction (do + mListenSock <- atomically $ do + SocketRegistry m _ <- readTVar sockVar + return (Map.lookup listenSid m) + case mListenSock of + Nothing -> return (errResult "invalid socket handle") + Just listenSock -> do + result <- try (NS.accept listenSock) :: IO (Either SomeException (NS.Socket, NS.SockAddr)) + case result of + Left e -> + return (errResult ("io error: " ++ show e)) + Right (clientSock, addr) -> do + clientSid <- atomically $ do + SocketRegistry m next <- readTVar sockVar + let sid = SockId next + writeTVar sockVar (SocketRegistry (Map.insert sid clientSock m) (next + 1)) + return sid + let addrStr = case addr of + NS.SockAddrInet p h -> + let (a,b,c,d) = NS.hostAddressToTuple h + in show a ++ "." ++ show b ++ "." ++ show c ++ "." ++ show d ++ ":" ++ show p + _ -> show addr + return (okResult (Fork (sockHandle clientSid) (ofString addrStr))) + ) machine) + + AConnect sockTree addrTree portTree -> + case decodeSockHandle sockTree of + Left _ -> finishValue machine invalidSocketHandleResult + Right sid -> + case decodeString addrTree "Connect" of + Left _ -> finishValue machine (errResult "invalid address") + Right addrStr -> + case toNumber portTree of + Left _ -> finishValue machine (errResult "invalid port") + Right port -> do + mSock <- atomically $ do + SocketRegistry m _ <- readTVar sockVar + return (Map.lookup sid m) + case mSock of + Nothing -> finishValue machine invalidSocketHandleResult + Just sock -> + pure (AsyncAction (do + result <- try (do + addrInfo <- NS.getAddrInfo (Just $ NS.defaultHints { NS.addrSocketType = NS.Stream }) + (Just addrStr) + (Just (show port)) + let serverAddr = head addrInfo + NS.connect sock (NS.addrAddress serverAddr) + ) :: IO (Either SomeException ()) + case result of + Left e -> + return (errResult ("io error: " ++ show e)) + Right () -> + return (okResult Leaf) + ) machine) + + ARecv sockTree maxBytesTree -> + case decodeSockHandle sockTree of + Left _ -> finishValue machine invalidSocketHandleResult + Right sid -> + case toNumber maxBytesTree of + Left _ -> finishValue machine (errResult "invalid maxBytes") + Right maxBytes -> do + mSock <- atomically $ do + SocketRegistry m _ <- readTVar sockVar + return (Map.lookup sid m) + case mSock of + Nothing -> finishValue machine invalidSocketHandleResult + Just sock -> + pure (AsyncAction (do + result <- try (NSB.recv sock (fromIntegral maxBytes)) :: IO (Either SomeException BS.ByteString) + case result of + Left e -> + return (errResult ("io error: " ++ show e)) + Right bs -> + if BS.null bs + then return (errResult "connection closed") + else return (okResult (ofBytes bs)) + ) machine) + + AGetSocketName sockTree -> + case decodeSockHandle sockTree of + Left _ -> finishValue machine invalidSocketHandleResult + Right sid -> do + mSock <- atomically $ do + SocketRegistry m _ <- readTVar sockVar + return (Map.lookup sid m) + case mSock of + Nothing -> finishValue machine invalidSocketHandleResult + Just sock -> do + mPort <- getSocketPort sock + case mPort of + Just port -> finishValue machine (okResult (ofNumber port)) + Nothing -> finishValue machine (errResult "io error: could not get socket name") + + ASend sockTree bytesTree -> + case decodeSockHandle sockTree of + Left _ -> finishValue machine invalidSocketHandleResult + Right sid -> + case decodeBytes bytesTree "Send" of + Left _ -> finishValue machine (errResult "invalid bytes") + Right bs -> do + mSock <- atomically $ do + SocketRegistry m _ <- readTVar sockVar + return (Map.lookup sid m) + case mSock of + Nothing -> finishValue machine invalidSocketHandleResult + Just sock -> + pure (AsyncAction (do + result <- try (NSB.send sock bs) :: IO (Either SomeException Int) + case result of + Left e -> + return (errResult ("io error: " ++ show e)) + Right sent -> + return (okResult (ofNumber (fromIntegral sent))) + ) machine) + -- Permission and IO helpers checkReadPerm p = if allowReadAll (rtPerms (machineRuntime machine)) @@ -543,6 +847,8 @@ data Scheduler = Scheduler , schedulerSleepQueue :: Map UTCTime (Set TaskId) , schedulerAsyncCompleted :: TVar (Map TaskId T) , schedulerCompleted :: Map TaskId (T, T) + , schedulerSockets :: TVar SocketRegistry + , schedulerNextSockId :: Integer } instance Show Scheduler where @@ -553,10 +859,12 @@ instance Show Scheduler where ++ ", schedulerSleepQueue = " ++ show (schedulerSleepQueue s) ++ ", schedulerAsyncCompleted = " ++ ", schedulerCompleted = " ++ show (schedulerCompleted s) + ++ ", schedulerSockets = " + ++ ", schedulerNextSockId = " ++ show (schedulerNextSockId s) ++ " }" -initialScheduler :: TVar (Map TaskId T) -> Machine -> Scheduler -initialScheduler asyncVar mainMachine = +initialScheduler :: TVar (Map TaskId T) -> TVar SocketRegistry -> Machine -> Scheduler +initialScheduler asyncVar sockVar mainMachine = Scheduler { schedulerNextTaskId = 1 , schedulerRunnable = Seq.singleton (TaskId 0) @@ -565,6 +873,8 @@ initialScheduler asyncVar mainMachine = , schedulerSleepQueue = Map.empty , schedulerAsyncCompleted = asyncVar , schedulerCompleted = Map.empty + , schedulerSockets = sockVar + , schedulerNextSockId = 0 } runtimeOfStatus :: TaskStatus -> Maybe Runtime @@ -725,8 +1035,11 @@ handleStep taskId (SleepRequested ms machine) scheduler = do handleStep taskId (AsyncAction ioAction machine) scheduler = do _ <- forkIO $ do - result <- ioAction - atomically $ modifyTVar' (schedulerAsyncCompleted scheduler) (Map.insert taskId result) + result <- (Right <$> ioAction) `catch` \(e :: SomeException) -> pure (Left (show e)) + atomically $ modifyTVar' (schedulerAsyncCompleted scheduler) (Map.insert taskId $ + case result of + Right val -> val + Left msg -> errResult msg) pure scheduler { schedulerTasks = Map.insert taskId (AsyncWaiting machine) (schedulerTasks scheduler) } @@ -784,7 +1097,7 @@ schedulerStep scheduler = do taskId :< restQueue -> case Map.lookup taskId (schedulerTasks scheduler1) of Just (Runnable machine) -> do - step <- stepMachine machine + step <- stepMachine (schedulerSockets scheduler1) machine handleStep taskId step scheduler1 { schedulerRunnable = restQueue } _ -> @@ -809,6 +1122,7 @@ runIOWith perms env initialState action = Left err -> pure (Left err) Right (_, action') -> do asyncVar <- newTVarIO Map.empty + sockVar <- newTVarIO (SocketRegistry Map.empty 0) let initialMachine = Machine { machineRuntime = Runtime { rtPerms = perms @@ -818,7 +1132,7 @@ runIOWith perms env initialState action = , machineCurrent = action' , machineFrames = [] } - Right <$> runScheduler (initialScheduler asyncVar initialMachine) + Right <$> runScheduler (initialScheduler asyncVar sockVar initialMachine) runIOWithEnv :: IOPermissions -> T -> T -> IO (Either String T) runIOWithEnv perms env action = do diff --git a/test/Spec.hs b/test/Spec.hs index 8d14580..feab34e 100644 --- a/test/Spec.hs +++ b/test/Spec.hs @@ -10,7 +10,8 @@ import Wire import ContentStore import IODriver (IOPermissions(..), checkIOSentinel, runIO, runIOWithEnv, runIOWith, unsafePerms, defaultPerms) -import Control.Exception (evaluate, try, SomeException) +import Control.Exception (bracket, evaluate, try, SomeException) +import qualified Network.Socket as NS import Control.Monad (forM_) import Control.Monad.IO.Class (liftIO) import System.IO.Temp (withSystemTempDirectory) @@ -1918,12 +1919,196 @@ ioDriverTests = testGroup "IO driver tests" build k = "bind (fork (pure \"x\")) (h : bind (await h) (_ : " ++ build (k - 1) ++ "))" (final, _) <- runIOSourceWith unsafePerms Leaf Leaf ("main = io (" ++ build n ++ ")") final @?= ofString "done" + + , testGroup "Socket primitives" + [ testCase "socket returns ok result with valid handle" $ do + final <- runIOSource "main = io socket" + final @?= ioOkResult (Fork (ofString "sock") (ofNumber 0)) + + , testCase "closeSocket on invalid handle returns error" $ do + final <- runIOSource "main = io (closeSocket (pair \"sock\" 99999))" + final @?= ioErrResult "invalid socket handle" + + , testCase "bindSocket and listen succeed on loopback port 0" $ do + final <- runIOSource "main = io (bind socket (result : matchResult (err rest : pure result) (sock rest : bind (bindSocket sock \"127.0.0.1\" 0) (bindResult : matchResult (err rest : pure bindResult) (_ rest : bind (listen sock 1) (listenResult : pure listenResult)) bindResult)) result))" + final @?= ioOkResult Leaf + + , testCase "connect to non-listening port returns error" $ do + final <- runIOSource $ + unlines + [ "main = io (bind socket (result :" + , " matchResult" + , " (err rest : pure \"socket-err\")" + , " (sock rest : connect sock \"127.0.0.1\" 1)" + , " result))" + ] + case final of + Fork Leaf (Fork _ Leaf) -> return () + other -> assertFailure $ "Expected error result, got: " ++ show other + + , testCase "accept and recv receive bytes from forked client" $ + withFreePort $ \port -> do + final <- runIOSource $ + unlines + [ "preserveResult = (result okCase :" + , " matchResult" + , " (err rest : pure result)" + , " okCase" + , " result)" + , "" + , "client = port :" + , " bind socket (result :" + , " preserveResult result (sock rest :" + , " bind (connect sock \"127.0.0.1\" port) (connectResult :" + , " preserveResult connectResult (_ rest :" + , " send sock [104 105]))))" + , "" + , "main = io (" + , " bind socket (result :" + , " preserveResult result (server rest :" + , " bind (bindSocket server \"127.0.0.1\" " ++ show port ++ ") (bindResult :" + , " preserveResult bindResult (_ rest :" + , " bind (listen server 1) (listenResult :" + , " preserveResult listenResult (_ rest :" + , " bind (fork (client " ++ show port ++ ")) (_ :" + , " bind (accept server) (acceptResult :" + , " preserveResult acceptResult (accepted rest :" + , " matchPair" + , " (clientSock addr : recv clientSock 2)" + , " accepted))))))))))" + ] + final @?= ioOkResult (ofBytes (BS.pack [104, 105])) + + , testCase "client recv receives server response via accepted socket" $ + withFreePort $ \port -> do + final <- runIOSource $ + unlines + [ "preserveResult = (result okCase :" + , " matchResult" + , " (err rest : pure result)" + , " okCase" + , " result)" + , "" + , "serverTask = server :" + , " bind (accept server) (acceptResult :" + , " preserveResult acceptResult (accepted rest :" + , " matchPair" + , " (clientSock addr :" + , " bind (recv clientSock 4) (msgResult :" + , " preserveResult msgResult (_ rest :" + , " send clientSock [112 111 110 103])))" + , " accepted))" + , "" + , "clientTask = port :" + , " bind socket (result :" + , " preserveResult result (sock rest :" + , " bind (connect sock \"127.0.0.1\" port) (connectResult :" + , " preserveResult connectResult (_ rest :" + , " bind (send sock [112 105 110 103]) (_ :" + , " recv sock 4)))))" + , "" + , "main = io (" + , " bind socket (result :" + , " preserveResult result (server rest :" + , " bind (bindSocket server \"127.0.0.1\" " ++ show port ++ ") (bindResult :" + , " preserveResult bindResult (_ rest :" + , " bind (listen server 1) (listenResult :" + , " preserveResult listenResult (_ rest :" + , " bind (fork (serverTask server)) (_ :" + , " clientTask " ++ show port ++ "))))))))" + ] + final @?= ioOkResult (ofBytes (BS.pack [112, 111, 110, 103])) + + , testCase "recv on closed peer returns connection closed" $ + withFreePort $ \port -> do + final <- runIOSource $ + unlines + [ "preserveResult = (result okCase :" + , " matchResult" + , " (err rest : pure result)" + , " okCase" + , " result)" + , "" + , "clientTask = port :" + , " bind socket (result :" + , " preserveResult result (sock rest :" + , " bind (connect sock \"127.0.0.1\" port) (connectResult :" + , " preserveResult connectResult (_ rest :" + , " closeSocket sock))))" + , "" + , "main = io (" + , " bind socket (result :" + , " preserveResult result (server rest :" + , " bind (bindSocket server \"127.0.0.1\" " ++ show port ++ ") (bindResult :" + , " preserveResult bindResult (_ rest :" + , " bind (listen server 1) (listenResult :" + , " preserveResult listenResult (_ rest :" + , " bind (fork (clientTask " ++ show port ++ ")) (_ :" + , " bind (accept server) (acceptResult :" + , " preserveResult acceptResult (accepted rest :" + , " matchPair" + , " (clientSock addr :" + , " bind (yield) (_ :" + , " recv clientSock 1))" + , " accepted))))))))))" + ] + final @?= ioErrResult "connection closed" + + , testCase "accept invalid socket handle returns error" $ do + final <- runIOSource "main = io (accept (pair \"sock\" 99999))" + final @?= ioErrResult "invalid socket handle" + + , testCase "recv invalid socket handle returns error" $ do + final <- runIOSource "main = io (recv (pair \"sock\" 99999) 1)" + final @?= ioErrResult "invalid socket handle" + + , testCase "send invalid socket handle returns error" $ do + final <- runIOSource "main = io (send (pair \"sock\" 99999) [(1)])" + final @?= ioErrResult "invalid socket handle" + + , testCase "getSocketName returns positive port after bind 0" $ do + final <- runIOSource $ + unlines + [ "preserveResult = (result okCase :" + , " matchResult" + , " (err rest : pure result)" + , " okCase" + , " result)" + , "" + , "main = io (" + , " bind socket (result :" + , " preserveResult result (server rest :" + , " bind (bindSocket server \"127.0.0.1\" 0) (bindResult :" + , " preserveResult bindResult (_ rest :" + , " getSocketName server)))))" + ] + case final of + Fork (Stem Leaf) (Fork val Leaf) -> + case toNumber val of + Right port | port > 0 -> return () + Right 0 -> assertFailure "Expected positive port, got 0" + Left _ -> assertFailure $ "Expected numeric port, got: " ++ show val + other -> assertFailure $ "Expected ok result, got: " ++ show other + ] ] +withFreePort :: (Int -> IO a) -> IO a +withFreePort action = + bracket + (NS.socket NS.AF_INET NS.Stream NS.defaultProtocol) + NS.close + (\s -> do + NS.setSocketOption s NS.ReuseAddr 1 + NS.bind s (NS.SockAddrInet 0 (NS.tupleToHostAddress (127, 0, 0, 1))) + port <- NS.socketPort s + action (fromIntegral port)) + runIOSourceWith :: IOPermissions -> T -> T -> String -> IO (T, T) runIOSourceWith perms readerEnv initialState source = do ioEnv <- evaluateFile "./lib/io.tri" - evalEnv <- evalTricuWithStore Nothing ioEnv (parseTricu source) + sockEnv <- evaluateFile "./lib/socket.tri" + let combinedEnv = Map.union sockEnv ioEnv + evalEnv <- evalTricuWithStore Nothing combinedEnv (parseTricu source) let fullTree = mainResult evalEnv result <- runIOWith perms readerEnv initialState fullTree case result of diff --git a/tricu.cabal b/tricu.cabal index e110109..4c4a334 100644 --- a/tricu.cabal +++ b/tricu.cabal @@ -52,6 +52,7 @@ executable tricu , megaparsec , memory , mtl + , network , servant , sqlite-simple , stm @@ -102,6 +103,7 @@ benchmark tricu-bench , megaparsec , memory , mtl + , network , sqlite-simple , text , time @@ -148,6 +150,7 @@ test-suite tricu-tests , megaparsec , memory , mtl + , network , servant , sqlite-simple , stm