2 Commits

Author SHA1 Message Date
020fa769a9 Event loop! 2026-05-19 17:00:36 -05:00
2e13583de3 Strings for IO driver errors 2026-05-18 19:12:42 -05:00
9 changed files with 691 additions and 113 deletions

View File

@@ -26,7 +26,7 @@
-- --
-- File operations return a Result tree (see lib/base.tri): -- File operations return a Result tree (see lib/base.tri):
-- ok value -- pair true (pair value t) -- ok value -- pair true (pair value t)
-- err code -- pair false (pair code t) -- err msg -- pair false (pair msg t)
-- --
-- Use onReadFile / onWriteFile for convenient branching. -- Use onReadFile / onWriteFile for convenient branching.
-- --

View File

@@ -0,0 +1,46 @@
!import "../../lib/base.tri" !Local
!import "../../lib/io.tri" !Local
!import "../../lib/socket.tri" !Local
-- Preserve the host-driver Result shape on error, run okCase on success.
onOk = action okCase :
bind action (result :
matchResult
(err rest : pure result)
okCase
result)
-- Convenience: print a string and continue.
printLn = s : bind (putStr (append s "\n")) (_ : pure t)
-- Main accept+echo loop. Recursion via y.
echoLoop = y (self server :
bind (accept server) (acceptResult :
matchResult
(err rest :
bind (printLn (append "accept error: " err)) (_ :
self server))
(accepted rest :
matchPair
(clientSock addr :
bind (printLn (append "client from " addr)) (_ :
bind (recv clientSock 4096) (msgResult :
matchResult
(err rest :
bind (closeSocket clientSock) (_ :
self server))
(msg rest :
bind (send clientSock msg) (_ :
bind (closeSocket clientSock) (_ :
self server)))
msgResult)))
accepted)
acceptResult))
main = io (
onOk socket (server rest :
onOk (bindSocket server "127.0.0.1" 0) (_ rest :
onOk (listen server 5) (_ rest :
onOk (getSocketName server) (port rest :
bind (printLn (append "Echo server listening on port " (showNumber port))) (_ :
echoLoop server))))))

View File

@@ -74,7 +74,7 @@ succ = y (self :
t)) t))
ok = value rest : pair true (pair value rest) ok = value rest : pair true (pair value rest)
err = code rest : pair false (pair code rest) err = msg rest : pair false (pair msg rest)
matchResult = (errCase okCase result : matchResult = (errCase okCase result :
matchPair matchPair

View File

@@ -63,24 +63,24 @@ onWriteFile = (path contents errCase okCase :
readFileOrPrintError = (path okCase : readFileOrPrintError = (path okCase :
onReadFile path onReadFile path
(err rest : putStrLn "Read failed") (err rest : putStrLn (append "Read failed: " err))
okCase) okCase)
writeFileOrPrintError = (path contents okCase : writeFileOrPrintError = (path contents okCase :
onWriteFile path contents onWriteFile path contents
(err rest : putStrLn "Write failed") (err rest : putStrLn (append "Write failed: " err))
okCase) okCase)
copyFile = (src dst : copyFile = (src dst :
bind (readFile src) bind (readFile src)
(result : (result :
matchResult matchResult
(err rest : putStrLn "Read failed") (err rest : putStrLn (append "Read failed: " err))
(contents rest : (contents rest :
bind (writeFile dst contents) bind (writeFile dst contents)
(wr : (wr :
matchResult matchResult
(err rest : putStrLn "Write failed") (err rest : putStrLn (append "Write failed: " err))
(ok rest : pure t) (ok rest : pure t)
wr)) wr))
result)) result))

63
lib/socket.tri Normal file
View File

@@ -0,0 +1,63 @@
!import "base.tri" !Local
!import "io.tri" !Local
-- Socket primitives for the IO driver.
-- All actions return a Result tree (see lib/base.tri):
-- ok value -- pair true (pair value t)
-- err msg -- pair false (pair msg t)
socket = pair 70 t
closeSocket = sock : pair 71 sock
bindSocket = sock addr port : pair 72 (pair sock (pair addr port))
listen = sock backlog : pair 73 (pair sock backlog)
accept = sock : pair 74 sock
connect = sock addr port : pair 75 (pair sock (pair addr port))
recv = sock maxBytes : pair 76 (pair sock maxBytes)
send = sock bytes : pair 77 (pair sock bytes)
getSocketName = sock : pair 78 sock
-- ---------------------------------------------------------------------------
-- Convenience helpers
-- ---------------------------------------------------------------------------
onSocket = (action errCase okCase :
bind action (result :
matchResult errCase okCase result))
-- Create a listening socket bound to an address and port.
-- Returns ok listenSocket or err message.
listenSocket = addr port backlog :
bind (socket) (result :
matchResult
(err rest : pure (err "socket creation failed"))
(sock rest :
bind (bindSocket sock addr port) (bindResult :
matchResult
(err rest : pure (err "bind failed"))
(_ rest :
bind (listen sock backlog) (listenResult :
matchResult
(err rest : pure (err "listen failed"))
(_ rest : pure (ok sock))
listenResult))
bindResult))
result)
-- Accept a connection and return (clientSocket, peerAddr).
-- The returned peerAddr is a string like "127.0.0.1:8080".
onAccept = (sock errCase okCase :
bind (accept sock) (result :
matchResult errCase okCase result))
-- Receive all available bytes up to maxBytes.
onRecv = (sock maxBytes errCase okCase :
bind (recv sock maxBytes) (result :
matchResult errCase okCase result))
-- Send bytes and return number of bytes sent.
onSend = (sock bytes errCase okCase :
bind (send sock bytes) (result :
matchResult errCase okCase result))
-- Close a socket, ignoring errors.
closeSocket_ = sock : bind (closeSocket sock) (_ : pure t)

View File

@@ -33,16 +33,15 @@ type Uses = [Bool]
evalSingle :: Env -> TricuAST -> Env evalSingle :: Env -> TricuAST -> Env
evalSingle env term evalSingle env term
| SDef name [] body <- term | SDef name params body <- term
= case Map.lookup name env of = let res = evalASTSync env (if null params then body else SLambda params body)
Just existingValue in case Map.lookup name env of
| existingValue == evalASTSync env body -> env Just existingValue
| otherwise | existingValue == res -> env
-> let res = evalASTSync env body | otherwise
in Map.insert "!result" res (Map.insert name res env) -> Map.insert "!result" res (Map.insert name res env)
Nothing Nothing
-> let res = evalASTSync env body -> Map.insert "!result" res (Map.insert name res env)
in Map.insert "!result" res (Map.insert name res env)
| SApp func arg <- term | SApp func arg <- term
= let res = apply (evalASTSync env func) (evalASTSync env arg) = let res = apply (evalASTSync env func) (evalASTSync env arg)
in Map.insert "!result" res env in Map.insert "!result" res env
@@ -57,7 +56,7 @@ evalSingle env term
in Map.insert "!result" res env in Map.insert "!result" res env
evalTricu :: Env -> [TricuAST] -> Env evalTricu :: Env -> [TricuAST] -> Env
evalTricu env x = go env (reorderDefs env x) evalTricu env x = go env (reorderDefs env (map recoverParams x))
where where
go env' [] = env' go env' [] = env'
go env' [def] = go env' [def] =
@@ -102,11 +101,10 @@ evalASTWithEnv mconn localEnv ast = do
let combinedEnv = Map.union localEnv storeEnv let combinedEnv = Map.union localEnv storeEnv
return $ evalASTSync combinedEnv ast return $ evalASTSync combinedEnv ast
-- | Store-aware version of 'evalSingle'.
evalSingleWithStore :: Maybe Connection -> Env -> TricuAST -> IO Env evalSingleWithStore :: Maybe Connection -> Env -> TricuAST -> IO Env
evalSingleWithStore mconn env term evalSingleWithStore mconn env term
| SDef name [] body <- term = do | SDef name params body <- term = do
res <- evalASTWithEnv mconn env body res <- evalASTWithEnv mconn env (if null params then body else SLambda params body)
case Map.lookup name env of case Map.lookup name env of
Just existingValue Just existingValue
| existingValue == res -> return env | existingValue == res -> return env
@@ -116,11 +114,8 @@ evalSingleWithStore mconn env term
res <- evalASTWithEnv mconn env term res <- evalASTWithEnv mconn env term
return $ Map.insert "!result" res env return $ Map.insert "!result" res env
-- | Store-aware version of 'evalTricu'. Does not preload the entire
-- content store; terms are resolved on demand as variables are
-- encountered.
evalTricuWithStore :: Maybe Connection -> Env -> [TricuAST] -> IO Env evalTricuWithStore :: Maybe Connection -> Env -> [TricuAST] -> IO Env
evalTricuWithStore mconn env x = go env (reorderDefs env x) evalTricuWithStore mconn env x = go env (reorderDefs env (map recoverParams x))
where where
go env' [] = return env' go env' [] = return env'
go env' [def] = do go env' [def] = do
@@ -130,6 +125,10 @@ evalTricuWithStore mconn env x = go env (reorderDefs env x)
updatedEnv <- evalSingleWithStore mconn env' def updatedEnv <- evalSingleWithStore mconn env' def
evalTricuWithStore mconn updatedEnv xs evalTricuWithStore mconn updatedEnv xs
recoverParams :: TricuAST -> TricuAST
recoverParams (SDef name [] (SLambda params body)) = SDef name params body
recoverParams term = term
collectVarNames :: TricuAST -> [(String, Maybe String)] collectVarNames :: TricuAST -> [(String, Maybe String)]
collectVarNames = go [] collectVarNames = go []
where where
@@ -189,6 +188,7 @@ elimLambda = go
| isSList term = slistTransform term | isSList term = slistTransform term
| otherwise = term | otherwise = term
etaReduction (SLambda [v] (SVar x Nothing)) = v == x
etaReduction (SLambda [v] (SApp f (SVar x Nothing))) = v == x && not (usesBinder v f) etaReduction (SLambda [v] (SApp f (SVar x Nothing))) = v == x && not (usesBinder v f)
etaReduction _ = False etaReduction _ = False
@@ -209,8 +209,9 @@ elimLambda = go
application (SApp _ _) = True application (SApp _ _) = True
application _ = False application _ = False
etaReduceResult (SLambda [_] (SVar _ Nothing)) = _I
etaReduceResult (SLambda [_] (SApp f _)) = f etaReduceResult (SLambda [_] (SApp f _)) = f
etaReduceResult _ = error "etaReduceResult: expected SLambda [v] (SApp f _)" etaReduceResult _ = error "etaReduceResult: unexpected shape"
lambdaListResult (SLambda [v] (SList xs)) = lambdaListResult (SLambda [v] (SList xs)) =
SLambda [v] (foldr wrapTLeaf TLeaf xs) SLambda [v] (foldr wrapTLeaf TLeaf xs)
@@ -254,12 +255,12 @@ composeBody f g x = SApp (SVar f Nothing) (SApp (SVar g Nothing) (SVar x Nothin
isFree :: String -> TricuAST -> Bool isFree :: String -> TricuAST -> Bool
isFree x t = Set.member x (freeVars t) isFree x t = Set.member x (freeVars t)
-- Keep old freeVars for compatibility with reorderDefs which still uses TricuAST
freeVars :: TricuAST -> Set String freeVars :: TricuAST -> Set String
freeVars (SVar v Nothing) = Set.singleton v freeVars (SVar v Nothing) = Set.singleton v
freeVars (SVar v (Just _)) = Set.singleton v freeVars (SVar v (Just _)) = Set.singleton v
freeVars (SApp t u) = Set.union (freeVars t) (freeVars u) freeVars (SApp t u) = Set.union (freeVars t) (freeVars u)
freeVars (SLambda vs body) = Set.difference (freeVars body) (Set.fromList vs) freeVars (SLambda vs body) = Set.difference (freeVars body) (Set.fromList vs)
freeVars (SDef _ params body) = Set.difference (freeVars body) (Set.fromList params)
freeVars (TStem t) = freeVars t freeVars (TStem t) = freeVars t
freeVars (TFork t u) = Set.union (freeVars t) (freeVars u) freeVars (TFork t u) = Set.union (freeVars t) (freeVars u)
freeVars (SList xs) = foldMap freeVars xs freeVars (SList xs) = foldMap freeVars xs
@@ -275,7 +276,7 @@ reorderDefs env defs
(defsOnly, others) = partition isDef defs (defsOnly, others) = partition isDef defs
defNames = [ name | SDef name _ _ <- defsOnly ] defNames = [ name | SDef name _ _ <- defsOnly ]
defsWithFreeVars = [(def, freeVars body) | def@(SDef _ _ body) <- defsOnly] defsWithFreeVars = [(def, freeVars def) | def <- defsOnly]
graph = buildDepGraph defsOnly graph = buildDepGraph defsOnly
sortedDefs = sortDeps graph sortedDefs = sortDeps graph
@@ -298,8 +299,8 @@ buildDepGraph topDefs
"Conflicting definitions detected: " ++ show conflictingDefs "Conflicting definitions detected: " ++ show conflictingDefs
| otherwise = | otherwise =
Map.fromList Map.fromList
[ (name, depends topDefs (SDef name [] body)) [ (name, depends topDefs def)
| SDef name _ body <- topDefs] | def@(SDef name _ _) <- topDefs]
where where
defsMap = Map.fromListWith (++) defsMap = Map.fromListWith (++)
[(name, [(name, body)]) | SDef name _ body <- topDefs] [(name, [(name, body)]) | SDef name _ body <- topDefs]
@@ -329,10 +330,10 @@ sortDeps graph = go [] Set.empty (Map.keys graph)
notReady notReady
depends :: [TricuAST] -> TricuAST -> Set.Set String depends :: [TricuAST] -> TricuAST -> Set.Set String
depends topDefs (SDef _ _ body) = depends topDefs def@(SDef _ _ _) =
Set.intersection Set.intersection
(Set.fromList [n | SDef n _ _ <- topDefs]) (Set.fromList [n | SDef n _ _ <- topDefs])
(freeVars body) (freeVars def)
depends _ _ = Set.empty depends _ _ = Set.empty
result :: Env -> T result :: Env -> T

View File

@@ -12,7 +12,7 @@ import Research (T(..), apply, toString, toNumber, ofString, ofNumber, ofBytes,
import qualified Data.ByteString as BS import qualified Data.ByteString as BS
import System.IO (putStr, getLine) import System.IO (putStr, getLine)
import qualified System.IO as IO import qualified System.IO as IO
import Control.Exception (try, IOException, SomeException) import Control.Exception (try, catch, IOException, SomeException)
import System.IO.Error (isDoesNotExistError, isPermissionError, isAlreadyExistsError) import System.IO.Error (isDoesNotExistError, isPermissionError, isAlreadyExistsError)
import Data.List (isPrefixOf) import Data.List (isPrefixOf)
import System.FilePath (normalise, isRelative, (</>), addTrailingPathSeparator, splitDirectories) import System.FilePath (normalise, isRelative, (</>), addTrailingPathSeparator, splitDirectories)
@@ -27,6 +27,8 @@ import Control.Concurrent.STM (TVar, newTVarIO, atomically, readTVar, writeTVar,
import qualified Data.Set as Set import qualified Data.Set as Set
import Data.Set (Set) import Data.Set (Set)
import qualified Data.Foldable as Fold import qualified Data.Foldable as Fold
import qualified Network.Socket as NS
import qualified Network.Socket.ByteString as NSB
-- --------------------------------------------------------------------------- -- ---------------------------------------------------------------------------
-- Permissions -- Permissions
@@ -95,67 +97,36 @@ data Machine = Machine
-- --
-- Runtime protocol errors are returned as direct values via errResult. -- Runtime protocol errors are returned as direct values via errResult.
-- Error code ranges:
-- 1-19 host IO / filesystem errors
-- 20-39 policy / permission errors
-- 40-59 protocol / decode / type errors
-- 60-79 async errors
-- 80-99 scheduler / runtime errors
-- Host IO / filesystem errors (1-19)
errDoesNotExist, errPermission, errAlreadyExists, errIOOther :: Integer
errDoesNotExist = 1
errPermission = 2
errAlreadyExists = 3
errIOOther = 4
-- Policy / permission errors (20-39)
errPolicyDeny :: Integer
errPolicyDeny = 20
-- Protocol / decode / type errors (40-59)
errInvalidAction, errInvalidString :: Integer
errInvalidAction = 40
errInvalidString = 41
-- Async errors (60-79)
errInvalidHandle, errSelfAwait, errInvalidSleep, errCyclicAwait :: Integer
errInvalidHandle = 60
errSelfAwait = 61
errInvalidSleep = 62
errCyclicAwait = 63
-- Scheduler / runtime errors (80-99)
errDeadlock :: Integer
errDeadlock = 80
ioErrorCode :: IOException -> Integer
ioErrorCode e
| isDoesNotExistError e = errDoesNotExist
| isPermissionError e = errPermission
| isAlreadyExistsError e = errAlreadyExists
| otherwise = errIOOther
okResult :: T -> T okResult :: T -> T
okResult val = Fork (Stem Leaf) (Fork val Leaf) okResult val = Fork (Stem Leaf) (Fork val Leaf)
errResult :: Integer -> T errResult :: String -> T
errResult code = Fork Leaf (Fork (ofNumber code) Leaf) errResult msg = Fork Leaf (Fork (ofString msg) Leaf)
pureAction :: T -> T pureAction :: T -> T
pureAction x = Fork (ofNumber 0) x pureAction x = Fork (ofNumber 0) x
invalidAsyncHandleResult :: T invalidAsyncHandleResult :: T
invalidAsyncHandleResult = errResult errInvalidHandle invalidAsyncHandleResult = errResult "invalid task handle"
invalidSocketHandleResult :: T
invalidSocketHandleResult = errResult "invalid socket handle"
selfAwaitResult :: T selfAwaitResult :: T
selfAwaitResult = errResult errSelfAwait selfAwaitResult = errResult "self await"
deadlockResult :: T deadlockResult :: T
deadlockResult = errResult errDeadlock deadlockResult = errResult "deadlock"
invalidSleepResult :: T invalidSleepResult :: T
invalidSleepResult = errResult errInvalidSleep invalidSleepResult = errResult "invalid sleep"
ioErrorString :: IOException -> String
ioErrorString e
| isDoesNotExistError e = "does not exist"
| isPermissionError e = "permission denied"
| isAlreadyExistsError e = "already exists"
| otherwise = "io error"
-- --------------------------------------------------------------------------- -- ---------------------------------------------------------------------------
-- Task identity and handles -- Task identity and handles
@@ -179,6 +150,45 @@ decodeTaskHandle tree =
_ -> _ ->
Left "invalid task handle" Left "invalid task handle"
-- ---------------------------------------------------------------------------
-- Socket identity and handles
-- ---------------------------------------------------------------------------
newtype SockId = SockId Integer
deriving (Eq, Ord, Show)
sockHandle :: SockId -> T
sockHandle (SockId n) =
Fork (ofString "sock") (ofNumber n)
decodeSockHandle :: T -> Either String SockId
decodeSockHandle tree =
case tree of
Fork tag nTree -> do
tagString <- toString tag
if tagString == "sock"
then SockId <$> toNumber nTree
else Left "invalid socket handle tag"
_ ->
Left "invalid socket handle"
getSocketPort :: NS.Socket -> IO (Maybe Integer)
getSocketPort sock = do
addr <- NS.getSocketName sock
case addr of
NS.SockAddrInet p _ -> return (Just (fromIntegral p))
NS.SockAddrInet6 p _ _ _ -> return (Just (fromIntegral p))
_ -> return Nothing
-- ---------------------------------------------------------------------------
-- Socket registry
-- ---------------------------------------------------------------------------
data SocketRegistry = SocketRegistry
{ sockMap :: Map SockId NS.Socket
, sockNextId :: Integer
}
-- --------------------------------------------------------------------------- -- ---------------------------------------------------------------------------
-- Free-monad action AST -- Free-monad action AST
-- --------------------------------------------------------------------------- -- ---------------------------------------------------------------------------
@@ -200,6 +210,15 @@ data Action
| AAwait T | AAwait T
| AYield | AYield
| ASleep T | ASleep T
| ASocket
| ACloseSocket T
| ABindSocket T T T
| AListen T T
| AAccept T
| AConnect T T T
| ARecv T T
| ASend T T
| AGetSocketName T
deriving (Show) deriving (Show)
-- --------------------------------------------------------------------------- -- ---------------------------------------------------------------------------
@@ -234,6 +253,19 @@ tagAwait = 61
tagYield = 62 tagYield = 62
tagSleep = 63 tagSleep = 63
tagSocket, tagCloseSocket, tagBindSocket, tagListen, tagAccept :: Integer
tagSocket = 70
tagCloseSocket = 71
tagBindSocket = 72
tagListen = 73
tagAccept = 74
tagConnect, tagRecv, tagSend, tagGetSocketName :: Integer
tagConnect = 75
tagRecv = 76
tagSend = 77
tagGetSocketName = 78
data Step data Step
= Halt Runtime T = Halt Runtime T
| Continue Machine | Continue Machine
@@ -313,6 +345,43 @@ decodeAction tree =
Right n | n == tagSleep -> Right n | n == tagSleep ->
Right (ASleep payload) Right (ASleep payload)
Right n | n == tagSocket ->
Right ASocket
Right n | n == tagCloseSocket ->
Right (ACloseSocket payload)
Right n | n == tagBindSocket ->
case payload of
Fork sock (Fork addr port) -> Right (ABindSocket sock addr port)
_ -> Left "Invalid BindSocket: expected pair sock (pair addr port)"
Right n | n == tagListen ->
case payload of
Fork sock backlog -> Right (AListen sock backlog)
_ -> Left "Invalid Listen: expected pair sock backlog"
Right n | n == tagAccept ->
Right (AAccept payload)
Right n | n == tagConnect ->
case payload of
Fork sock (Fork addr port) -> Right (AConnect sock addr port)
_ -> Left "Invalid Connect: expected pair sock (pair addr port)"
Right n | n == tagRecv ->
case payload of
Fork sock maxBytes -> Right (ARecv sock maxBytes)
_ -> Left "Invalid Recv: expected pair sock maxBytes"
Right n | n == tagSend ->
case payload of
Fork sock bytes -> Right (ASend sock bytes)
_ -> Left "Invalid Send: expected pair sock bytes"
Right n | n == tagGetSocketName ->
Right (AGetSocketName payload)
Right n -> Right n ->
Left $ "Unknown IO action tag: " ++ show n Left $ "Unknown IO action tag: " ++ show n
@@ -346,11 +415,11 @@ finishValue machine value =
, machineFrames = rest , machineFrames = rest
}) })
stepMachine :: Machine -> IO Step stepMachine :: TVar SocketRegistry -> Machine -> IO Step
stepMachine machine = stepMachine sockVar machine =
case decodeAction (machineCurrent machine) of case decodeAction (machineCurrent machine) of
Right action -> dispatch action Right action -> dispatch action
Left _ -> finishValue machine (errResult errInvalidAction) Left _ -> finishValue machine (errResult "invalid action")
where where
dispatch action = case action of dispatch action = case action of
APure val -> APure val ->
@@ -367,14 +436,14 @@ stepMachine machine =
Right s -> Right s ->
pure (AsyncAction (putStr s >> pure Leaf) machine) pure (AsyncAction (putStr s >> pure Leaf) machine)
Left _ -> Left _ ->
finishValue machine (errResult errInvalidString) finishValue machine (errResult "invalid string")
APutBytes bs -> APutBytes bs ->
case decodeBytes bs "PutBytes" of case decodeBytes bs "PutBytes" of
Right b -> Right b ->
pure (AsyncAction (BS.putStr b >> pure Leaf) machine) pure (AsyncAction (BS.putStr b >> pure Leaf) machine)
Left _ -> Left _ ->
finishValue machine (errResult errInvalidString) finishValue machine (errResult "invalid bytes")
AGetLine -> AGetLine ->
pure (AsyncAction (ofString <$> getLine) machine) pure (AsyncAction (ofString <$> getLine) machine)
@@ -386,7 +455,7 @@ stepMachine machine =
case mDeny of case mDeny of
Just denied -> finishValue machine denied Just denied -> finishValue machine denied
Nothing -> pure (AsyncAction (tryReadFile p) machine) Nothing -> pure (AsyncAction (tryReadFile p) machine)
Left _ -> finishValue machine (errResult errInvalidString) Left _ -> finishValue machine (errResult "invalid string")
AWriteFile path contents -> AWriteFile path contents ->
case decodeString path "WriteFile" of case decodeString path "WriteFile" of
@@ -397,8 +466,8 @@ stepMachine machine =
case mDeny of case mDeny of
Just denied -> finishValue machine denied Just denied -> finishValue machine denied
Nothing -> pure (AsyncAction (tryWriteFile p c) machine) Nothing -> pure (AsyncAction (tryWriteFile p c) machine)
Left _ -> finishValue machine (errResult errInvalidString) Left _ -> finishValue machine (errResult "invalid string")
Left _ -> finishValue machine (errResult errInvalidString) Left _ -> finishValue machine (errResult "invalid string")
AWriteBytes path contents -> AWriteBytes path contents ->
case decodeString path "WriteBytes" of case decodeString path "WriteBytes" of
@@ -409,8 +478,8 @@ stepMachine machine =
case mDeny of case mDeny of
Just denied -> finishValue machine denied Just denied -> finishValue machine denied
Nothing -> pure (AsyncAction (tryWriteFileBytes p c) machine) Nothing -> pure (AsyncAction (tryWriteFileBytes p c) machine)
Left _ -> finishValue machine (errResult errInvalidString) Left _ -> finishValue machine (errResult "invalid bytes")
Left _ -> finishValue machine (errResult errInvalidString) Left _ -> finishValue machine (errResult "invalid string")
AAsk -> AAsk ->
finishValue machine (rtEnv (machineRuntime machine)) finishValue machine (rtEnv (machineRuntime machine))
@@ -453,6 +522,207 @@ stepMachine machine =
_ -> _ ->
finishValue machine invalidSleepResult finishValue machine invalidSleepResult
ASocket -> do
result <- try (NS.socket NS.AF_INET NS.Stream NS.defaultProtocol) :: IO (Either SomeException NS.Socket)
case result of
Left e ->
finishValue machine (errResult ("io error: " ++ show e))
Right sock -> do
NS.setSocketOption sock NS.ReuseAddr 1
sid <- atomically $ do
SocketRegistry m next <- readTVar sockVar
let sid = SockId next
writeTVar sockVar (SocketRegistry (Map.insert sid sock m) (next + 1))
return sid
finishValue machine (okResult (sockHandle sid))
ACloseSocket sockTree ->
case decodeSockHandle sockTree of
Left _ -> finishValue machine invalidSocketHandleResult
Right sid -> do
mSock <- atomically $ do
SocketRegistry m next <- readTVar sockVar
case Map.lookup sid m of
Nothing -> return Nothing
Just sock -> do
writeTVar sockVar (SocketRegistry (Map.delete sid m) next)
return (Just sock)
case mSock of
Nothing -> finishValue machine invalidSocketHandleResult
Just sock -> do
NS.close sock
finishValue machine (okResult Leaf)
ABindSocket sockTree addrTree portTree ->
case decodeSockHandle sockTree of
Left _ -> finishValue machine invalidSocketHandleResult
Right sid ->
case decodeString addrTree "BindSocket" of
Left _ -> finishValue machine (errResult "invalid address")
Right addrStr ->
case toNumber portTree of
Left _ -> finishValue machine (errResult "invalid port")
Right port -> do
mSock <- atomically $ do
SocketRegistry m _ <- readTVar sockVar
return (Map.lookup sid m)
case mSock of
Nothing -> finishValue machine invalidSocketHandleResult
Just sock -> do
result <- try (do
addrInfo <- NS.getAddrInfo (Just $ NS.defaultHints { NS.addrSocketType = NS.Stream })
(Just addrStr)
(Just (show port))
let serverAddr = head addrInfo
NS.bind sock (NS.addrAddress serverAddr)
) :: IO (Either SomeException ())
case result of
Left e ->
finishValue machine (errResult ("io error: " ++ show e))
Right () ->
finishValue machine (okResult Leaf)
AListen sockTree backlogTree ->
case decodeSockHandle sockTree of
Left _ -> finishValue machine invalidSocketHandleResult
Right sid ->
case toNumber backlogTree of
Left _ -> finishValue machine (errResult "invalid backlog")
Right backlog -> do
mSock <- atomically $ do
SocketRegistry m _ <- readTVar sockVar
return (Map.lookup sid m)
case mSock of
Nothing -> finishValue machine invalidSocketHandleResult
Just sock -> do
result <- try (NS.listen sock (fromIntegral backlog)) :: IO (Either SomeException ())
case result of
Left e ->
finishValue machine (errResult ("io error: " ++ show e))
Right () ->
finishValue machine (okResult Leaf)
AAccept listenTree ->
case decodeSockHandle listenTree of
Left _ -> finishValue machine invalidSocketHandleResult
Right listenSid ->
pure (AsyncAction (do
mListenSock <- atomically $ do
SocketRegistry m _ <- readTVar sockVar
return (Map.lookup listenSid m)
case mListenSock of
Nothing -> return (errResult "invalid socket handle")
Just listenSock -> do
result <- try (NS.accept listenSock) :: IO (Either SomeException (NS.Socket, NS.SockAddr))
case result of
Left e ->
return (errResult ("io error: " ++ show e))
Right (clientSock, addr) -> do
clientSid <- atomically $ do
SocketRegistry m next <- readTVar sockVar
let sid = SockId next
writeTVar sockVar (SocketRegistry (Map.insert sid clientSock m) (next + 1))
return sid
let addrStr = case addr of
NS.SockAddrInet p h ->
let (a,b,c,d) = NS.hostAddressToTuple h
in show a ++ "." ++ show b ++ "." ++ show c ++ "." ++ show d ++ ":" ++ show p
_ -> show addr
return (okResult (Fork (sockHandle clientSid) (ofString addrStr)))
) machine)
AConnect sockTree addrTree portTree ->
case decodeSockHandle sockTree of
Left _ -> finishValue machine invalidSocketHandleResult
Right sid ->
case decodeString addrTree "Connect" of
Left _ -> finishValue machine (errResult "invalid address")
Right addrStr ->
case toNumber portTree of
Left _ -> finishValue machine (errResult "invalid port")
Right port -> do
mSock <- atomically $ do
SocketRegistry m _ <- readTVar sockVar
return (Map.lookup sid m)
case mSock of
Nothing -> finishValue machine invalidSocketHandleResult
Just sock ->
pure (AsyncAction (do
result <- try (do
addrInfo <- NS.getAddrInfo (Just $ NS.defaultHints { NS.addrSocketType = NS.Stream })
(Just addrStr)
(Just (show port))
let serverAddr = head addrInfo
NS.connect sock (NS.addrAddress serverAddr)
) :: IO (Either SomeException ())
case result of
Left e ->
return (errResult ("io error: " ++ show e))
Right () ->
return (okResult Leaf)
) machine)
ARecv sockTree maxBytesTree ->
case decodeSockHandle sockTree of
Left _ -> finishValue machine invalidSocketHandleResult
Right sid ->
case toNumber maxBytesTree of
Left _ -> finishValue machine (errResult "invalid maxBytes")
Right maxBytes -> do
mSock <- atomically $ do
SocketRegistry m _ <- readTVar sockVar
return (Map.lookup sid m)
case mSock of
Nothing -> finishValue machine invalidSocketHandleResult
Just sock ->
pure (AsyncAction (do
result <- try (NSB.recv sock (fromIntegral maxBytes)) :: IO (Either SomeException BS.ByteString)
case result of
Left e ->
return (errResult ("io error: " ++ show e))
Right bs ->
if BS.null bs
then return (errResult "connection closed")
else return (okResult (ofBytes bs))
) machine)
AGetSocketName sockTree ->
case decodeSockHandle sockTree of
Left _ -> finishValue machine invalidSocketHandleResult
Right sid -> do
mSock <- atomically $ do
SocketRegistry m _ <- readTVar sockVar
return (Map.lookup sid m)
case mSock of
Nothing -> finishValue machine invalidSocketHandleResult
Just sock -> do
mPort <- getSocketPort sock
case mPort of
Just port -> finishValue machine (okResult (ofNumber port))
Nothing -> finishValue machine (errResult "io error: could not get socket name")
ASend sockTree bytesTree ->
case decodeSockHandle sockTree of
Left _ -> finishValue machine invalidSocketHandleResult
Right sid ->
case decodeBytes bytesTree "Send" of
Left _ -> finishValue machine (errResult "invalid bytes")
Right bs -> do
mSock <- atomically $ do
SocketRegistry m _ <- readTVar sockVar
return (Map.lookup sid m)
case mSock of
Nothing -> finishValue machine invalidSocketHandleResult
Just sock ->
pure (AsyncAction (do
result <- try (NSB.send sock bs) :: IO (Either SomeException Int)
case result of
Left e ->
return (errResult ("io error: " ++ show e))
Right sent ->
return (okResult (ofNumber (fromIntegral sent)))
) machine)
-- Permission and IO helpers -- Permission and IO helpers
checkReadPerm p = checkReadPerm p =
if allowReadAll (rtPerms (machineRuntime machine)) if allowReadAll (rtPerms (machineRuntime machine))
@@ -480,7 +750,7 @@ stepMachine machine =
then return Nothing then return Nothing
else return $ Just policyErrResult else return $ Just policyErrResult
policyErrResult = errResult errPolicyDeny policyErrResult = errResult "permission denied"
canonicalizeSafe :: FilePath -> IO (Either String FilePath) canonicalizeSafe :: FilePath -> IO (Either String FilePath)
canonicalizeSafe p = do canonicalizeSafe p = do
@@ -534,19 +804,19 @@ stepMachine machine =
result <- try (BS.readFile path) :: IO (Either IOException BS.ByteString) result <- try (BS.readFile path) :: IO (Either IOException BS.ByteString)
case result of case result of
Right content -> return $ okResult (ofBytes content) Right content -> return $ okResult (ofBytes content)
Left e -> return $ errResult (ioErrorCode e) Left e -> return $ errResult (ioErrorString e)
tryWriteFile path contents = do tryWriteFile path contents = do
result <- try (IO.writeFile path contents) :: IO (Either IOException ()) result <- try (IO.writeFile path contents) :: IO (Either IOException ())
case result of case result of
Right () -> return $ okResult Leaf Right () -> return $ okResult Leaf
Left e -> return $ errResult (ioErrorCode e) Left e -> return $ errResult (ioErrorString e)
tryWriteFileBytes path contents = do tryWriteFileBytes path contents = do
result <- try (BS.writeFile path contents) :: IO (Either IOException ()) result <- try (BS.writeFile path contents) :: IO (Either IOException ())
case result of case result of
Right () -> return $ okResult Leaf Right () -> return $ okResult Leaf
Left e -> return $ errResult (ioErrorCode e) Left e -> return $ errResult (ioErrorString e)
decodeString t ctx = decodeString t ctx =
case toString t of case toString t of
@@ -577,6 +847,8 @@ data Scheduler = Scheduler
, schedulerSleepQueue :: Map UTCTime (Set TaskId) , schedulerSleepQueue :: Map UTCTime (Set TaskId)
, schedulerAsyncCompleted :: TVar (Map TaskId T) , schedulerAsyncCompleted :: TVar (Map TaskId T)
, schedulerCompleted :: Map TaskId (T, T) , schedulerCompleted :: Map TaskId (T, T)
, schedulerSockets :: TVar SocketRegistry
, schedulerNextSockId :: Integer
} }
instance Show Scheduler where instance Show Scheduler where
@@ -587,10 +859,12 @@ instance Show Scheduler where
++ ", schedulerSleepQueue = " ++ show (schedulerSleepQueue s) ++ ", schedulerSleepQueue = " ++ show (schedulerSleepQueue s)
++ ", schedulerAsyncCompleted = <tvar>" ++ ", schedulerAsyncCompleted = <tvar>"
++ ", schedulerCompleted = " ++ show (schedulerCompleted s) ++ ", schedulerCompleted = " ++ show (schedulerCompleted s)
++ ", schedulerSockets = <tvar>"
++ ", schedulerNextSockId = " ++ show (schedulerNextSockId s)
++ " }" ++ " }"
initialScheduler :: TVar (Map TaskId T) -> Machine -> Scheduler initialScheduler :: TVar (Map TaskId T) -> TVar SocketRegistry -> Machine -> Scheduler
initialScheduler asyncVar mainMachine = initialScheduler asyncVar sockVar mainMachine =
Scheduler Scheduler
{ schedulerNextTaskId = 1 { schedulerNextTaskId = 1
, schedulerRunnable = Seq.singleton (TaskId 0) , schedulerRunnable = Seq.singleton (TaskId 0)
@@ -599,6 +873,8 @@ initialScheduler asyncVar mainMachine =
, schedulerSleepQueue = Map.empty , schedulerSleepQueue = Map.empty
, schedulerAsyncCompleted = asyncVar , schedulerAsyncCompleted = asyncVar
, schedulerCompleted = Map.empty , schedulerCompleted = Map.empty
, schedulerSockets = sockVar
, schedulerNextSockId = 0
} }
runtimeOfStatus :: TaskStatus -> Maybe Runtime runtimeOfStatus :: TaskStatus -> Maybe Runtime
@@ -732,7 +1008,7 @@ handleStep currentId (AwaitRequested targetId machine) scheduler
Just (BlockedOn nextId _) -> Just (BlockedOn nextId _) ->
if wouldCycle targetId currentId (schedulerTasks scheduler) if wouldCycle targetId currentId (schedulerTasks scheduler)
then resumeCurrentWith currentId (errResult errCyclicAwait) machine scheduler then resumeCurrentWith currentId (errResult "cyclic await") machine scheduler
else block else block
Just _ -> block Just _ -> block
@@ -759,8 +1035,11 @@ handleStep taskId (SleepRequested ms machine) scheduler = do
handleStep taskId (AsyncAction ioAction machine) scheduler = do handleStep taskId (AsyncAction ioAction machine) scheduler = do
_ <- forkIO $ do _ <- forkIO $ do
result <- ioAction result <- (Right <$> ioAction) `catch` \(e :: SomeException) -> pure (Left (show e))
atomically $ modifyTVar' (schedulerAsyncCompleted scheduler) (Map.insert taskId result) atomically $ modifyTVar' (schedulerAsyncCompleted scheduler) (Map.insert taskId $
case result of
Right val -> val
Left msg -> errResult msg)
pure scheduler pure scheduler
{ schedulerTasks = Map.insert taskId (AsyncWaiting machine) (schedulerTasks scheduler) { schedulerTasks = Map.insert taskId (AsyncWaiting machine) (schedulerTasks scheduler)
} }
@@ -818,7 +1097,7 @@ schedulerStep scheduler = do
taskId :< restQueue -> taskId :< restQueue ->
case Map.lookup taskId (schedulerTasks scheduler1) of case Map.lookup taskId (schedulerTasks scheduler1) of
Just (Runnable machine) -> do Just (Runnable machine) -> do
step <- stepMachine machine step <- stepMachine (schedulerSockets scheduler1) machine
handleStep taskId step scheduler1 { schedulerRunnable = restQueue } handleStep taskId step scheduler1 { schedulerRunnable = restQueue }
_ -> _ ->
@@ -843,6 +1122,7 @@ runIOWith perms env initialState action =
Left err -> pure (Left err) Left err -> pure (Left err)
Right (_, action') -> do Right (_, action') -> do
asyncVar <- newTVarIO Map.empty asyncVar <- newTVarIO Map.empty
sockVar <- newTVarIO (SocketRegistry Map.empty 0)
let initialMachine = Machine let initialMachine = Machine
{ machineRuntime = Runtime { machineRuntime = Runtime
{ rtPerms = perms { rtPerms = perms
@@ -852,7 +1132,7 @@ runIOWith perms env initialState action =
, machineCurrent = action' , machineCurrent = action'
, machineFrames = [] , machineFrames = []
} }
Right <$> runScheduler (initialScheduler asyncVar initialMachine) Right <$> runScheduler (initialScheduler asyncVar sockVar initialMachine)
runIOWithEnv :: IOPermissions -> T -> T -> IO (Either String T) runIOWithEnv :: IOPermissions -> T -> T -> IO (Either String T)
runIOWithEnv perms env action = do runIOWithEnv perms env action = do

View File

@@ -10,7 +10,8 @@ import Wire
import ContentStore import ContentStore
import IODriver (IOPermissions(..), checkIOSentinel, runIO, runIOWithEnv, runIOWith, unsafePerms, defaultPerms) import IODriver (IOPermissions(..), checkIOSentinel, runIO, runIOWithEnv, runIOWith, unsafePerms, defaultPerms)
import Control.Exception (evaluate, try, SomeException) import Control.Exception (bracket, evaluate, try, SomeException)
import qualified Network.Socket as NS
import Control.Monad (forM_) import Control.Monad (forM_)
import Control.Monad.IO.Class (liftIO) import Control.Monad.IO.Class (liftIO)
import System.IO.Temp (withSystemTempDirectory) import System.IO.Temp (withSystemTempDirectory)
@@ -1553,15 +1554,15 @@ ioDriverTests = testGroup "IO driver tests"
-- Malformed action tests -- Malformed action tests
, testCase "unknown IO action tag returns err result" $ do , testCase "unknown IO action tag returns err result" $ do
final <- runIOSource "main = io (pair 99 t)" final <- runIOSource "main = io (pair 99 t)"
final @?= ioErrResult 40 final @?= ioErrResult "invalid action"
, testCase "malformed Bind returns err result" $ do , testCase "malformed Bind returns err result" $ do
final <- runIOSource "main = io (pair 1 t)" final <- runIOSource "main = io (pair 1 t)"
final @?= ioErrResult 40 final @?= ioErrResult "invalid action"
, testCase "malformed ReadFile payload returns err result" $ do , testCase "malformed ReadFile payload returns err result" $ do
final <- runIOSource "main = io (readFile (t t))" final <- runIOSource "main = io (readFile (t t))"
final @?= ioErrResult 41 final @?= ioErrResult "invalid string"
-- Permission tests -- Permission tests
, testCase "allowed read path succeeds" $ , testCase "allowed read path succeeds" $
@@ -1586,7 +1587,7 @@ ioDriverTests = testGroup "IO driver tests"
unlines unlines
[ "main = io (readFile \"" ++ deniedPath ++ "\")" [ "main = io (readFile \"" ++ deniedPath ++ "\")"
] ]
result @?= ioErrResult 20 result @?= ioErrResult "permission denied"
, testCase "writeFile denied path returns err result" $ , testCase "writeFile denied path returns err result" $
withSystemTempDirectory "tricu-io-write-denied" $ \dir -> do withSystemTempDirectory "tricu-io-write-denied" $ \dir -> do
@@ -1597,7 +1598,7 @@ ioDriverTests = testGroup "IO driver tests"
unlines unlines
[ "main = io (writeFile \"" ++ deniedPath ++ "\" \"x\")" [ "main = io (writeFile \"" ++ deniedPath ++ "\" \"x\")"
] ]
result @?= ioErrResult 20 result @?= ioErrResult "permission denied"
, testCase "path prefix does not allow prefix bypass" $ , testCase "path prefix does not allow prefix bypass" $
withSystemTempDirectory "tricu-io-prefix" $ \dir -> do withSystemTempDirectory "tricu-io-prefix" $ \dir -> do
@@ -1611,7 +1612,7 @@ ioDriverTests = testGroup "IO driver tests"
unlines unlines
[ "main = io (readFile \"" ++ bypassPath ++ "\")" [ "main = io (readFile \"" ++ bypassPath ++ "\")"
] ]
result @?= ioErrResult 20 result @?= ioErrResult "permission denied"
-- Pure test -- Pure test
, testCase "pure performs no effects" $ do , testCase "pure performs no effects" $ do
@@ -1820,14 +1821,14 @@ ioDriverTests = testGroup "IO driver tests"
unlines unlines
[ "main = io (await (pair \"task\" 0))" [ "main = io (await (pair \"task\" 0))"
] ]
final @?= ioErrResult 61 final @?= ioErrResult "self await"
, testCase "await invalid handle returns async error" $ do , testCase "await invalid handle returns async error" $ do
(final, _) <- runIOSourceWith unsafePerms Leaf Leaf $ (final, _) <- runIOSourceWith unsafePerms Leaf Leaf $
unlines unlines
[ "main = io (await 123)" [ "main = io (await 123)"
] ]
final @?= ioErrResult 60 final @?= ioErrResult "invalid task handle"
, testCase "yield returns unit and resumes continuation" $ do , testCase "yield returns unit and resumes continuation" $ do
(final, _) <- runIOSourceWith unsafePerms Leaf Leaf $ (final, _) <- runIOSourceWith unsafePerms Leaf Leaf $
@@ -1890,7 +1891,7 @@ ioDriverTests = testGroup "IO driver tests"
[ "main = io (bind (fork (await (pair \"task\" 0))) (h :" [ "main = io (bind (fork (await (pair \"task\" 0))) (h :"
, " await h))" , " await h))"
] ]
final @?= ioErrResult 63 final @?= ioErrResult "cyclic await"
, testCase "writeBytes and readFile roundtrip binary data" $ , testCase "writeBytes and readFile roundtrip binary data" $
withSystemTempDirectory "tricu-io-bytes" $ \dir -> do withSystemTempDirectory "tricu-io-bytes" $ \dir -> do
@@ -1918,12 +1919,196 @@ ioDriverTests = testGroup "IO driver tests"
build k = "bind (fork (pure \"x\")) (h : bind (await h) (_ : " ++ build (k - 1) ++ "))" build k = "bind (fork (pure \"x\")) (h : bind (await h) (_ : " ++ build (k - 1) ++ "))"
(final, _) <- runIOSourceWith unsafePerms Leaf Leaf ("main = io (" ++ build n ++ ")") (final, _) <- runIOSourceWith unsafePerms Leaf Leaf ("main = io (" ++ build n ++ ")")
final @?= ofString "done" final @?= ofString "done"
, testGroup "Socket primitives"
[ testCase "socket returns ok result with valid handle" $ do
final <- runIOSource "main = io socket"
final @?= ioOkResult (Fork (ofString "sock") (ofNumber 0))
, testCase "closeSocket on invalid handle returns error" $ do
final <- runIOSource "main = io (closeSocket (pair \"sock\" 99999))"
final @?= ioErrResult "invalid socket handle"
, testCase "bindSocket and listen succeed on loopback port 0" $ do
final <- runIOSource "main = io (bind socket (result : matchResult (err rest : pure result) (sock rest : bind (bindSocket sock \"127.0.0.1\" 0) (bindResult : matchResult (err rest : pure bindResult) (_ rest : bind (listen sock 1) (listenResult : pure listenResult)) bindResult)) result))"
final @?= ioOkResult Leaf
, testCase "connect to non-listening port returns error" $ do
final <- runIOSource $
unlines
[ "main = io (bind socket (result :"
, " matchResult"
, " (err rest : pure \"socket-err\")"
, " (sock rest : connect sock \"127.0.0.1\" 1)"
, " result))"
]
case final of
Fork Leaf (Fork _ Leaf) -> return ()
other -> assertFailure $ "Expected error result, got: " ++ show other
, testCase "accept and recv receive bytes from forked client" $
withFreePort $ \port -> do
final <- runIOSource $
unlines
[ "preserveResult = (result okCase :"
, " matchResult"
, " (err rest : pure result)"
, " okCase"
, " result)"
, ""
, "client = port :"
, " bind socket (result :"
, " preserveResult result (sock rest :"
, " bind (connect sock \"127.0.0.1\" port) (connectResult :"
, " preserveResult connectResult (_ rest :"
, " send sock [104 105]))))"
, ""
, "main = io ("
, " bind socket (result :"
, " preserveResult result (server rest :"
, " bind (bindSocket server \"127.0.0.1\" " ++ show port ++ ") (bindResult :"
, " preserveResult bindResult (_ rest :"
, " bind (listen server 1) (listenResult :"
, " preserveResult listenResult (_ rest :"
, " bind (fork (client " ++ show port ++ ")) (_ :"
, " bind (accept server) (acceptResult :"
, " preserveResult acceptResult (accepted rest :"
, " matchPair"
, " (clientSock addr : recv clientSock 2)"
, " accepted))))))))))"
]
final @?= ioOkResult (ofBytes (BS.pack [104, 105]))
, testCase "client recv receives server response via accepted socket" $
withFreePort $ \port -> do
final <- runIOSource $
unlines
[ "preserveResult = (result okCase :"
, " matchResult"
, " (err rest : pure result)"
, " okCase"
, " result)"
, ""
, "serverTask = server :"
, " bind (accept server) (acceptResult :"
, " preserveResult acceptResult (accepted rest :"
, " matchPair"
, " (clientSock addr :"
, " bind (recv clientSock 4) (msgResult :"
, " preserveResult msgResult (_ rest :"
, " send clientSock [112 111 110 103])))"
, " accepted))"
, ""
, "clientTask = port :"
, " bind socket (result :"
, " preserveResult result (sock rest :"
, " bind (connect sock \"127.0.0.1\" port) (connectResult :"
, " preserveResult connectResult (_ rest :"
, " bind (send sock [112 105 110 103]) (_ :"
, " recv sock 4)))))"
, ""
, "main = io ("
, " bind socket (result :"
, " preserveResult result (server rest :"
, " bind (bindSocket server \"127.0.0.1\" " ++ show port ++ ") (bindResult :"
, " preserveResult bindResult (_ rest :"
, " bind (listen server 1) (listenResult :"
, " preserveResult listenResult (_ rest :"
, " bind (fork (serverTask server)) (_ :"
, " clientTask " ++ show port ++ "))))))))"
]
final @?= ioOkResult (ofBytes (BS.pack [112, 111, 110, 103]))
, testCase "recv on closed peer returns connection closed" $
withFreePort $ \port -> do
final <- runIOSource $
unlines
[ "preserveResult = (result okCase :"
, " matchResult"
, " (err rest : pure result)"
, " okCase"
, " result)"
, ""
, "clientTask = port :"
, " bind socket (result :"
, " preserveResult result (sock rest :"
, " bind (connect sock \"127.0.0.1\" port) (connectResult :"
, " preserveResult connectResult (_ rest :"
, " closeSocket sock))))"
, ""
, "main = io ("
, " bind socket (result :"
, " preserveResult result (server rest :"
, " bind (bindSocket server \"127.0.0.1\" " ++ show port ++ ") (bindResult :"
, " preserveResult bindResult (_ rest :"
, " bind (listen server 1) (listenResult :"
, " preserveResult listenResult (_ rest :"
, " bind (fork (clientTask " ++ show port ++ ")) (_ :"
, " bind (accept server) (acceptResult :"
, " preserveResult acceptResult (accepted rest :"
, " matchPair"
, " (clientSock addr :"
, " bind (yield) (_ :"
, " recv clientSock 1))"
, " accepted))))))))))"
]
final @?= ioErrResult "connection closed"
, testCase "accept invalid socket handle returns error" $ do
final <- runIOSource "main = io (accept (pair \"sock\" 99999))"
final @?= ioErrResult "invalid socket handle"
, testCase "recv invalid socket handle returns error" $ do
final <- runIOSource "main = io (recv (pair \"sock\" 99999) 1)"
final @?= ioErrResult "invalid socket handle"
, testCase "send invalid socket handle returns error" $ do
final <- runIOSource "main = io (send (pair \"sock\" 99999) [(1)])"
final @?= ioErrResult "invalid socket handle"
, testCase "getSocketName returns positive port after bind 0" $ do
final <- runIOSource $
unlines
[ "preserveResult = (result okCase :"
, " matchResult"
, " (err rest : pure result)"
, " okCase"
, " result)"
, ""
, "main = io ("
, " bind socket (result :"
, " preserveResult result (server rest :"
, " bind (bindSocket server \"127.0.0.1\" 0) (bindResult :"
, " preserveResult bindResult (_ rest :"
, " getSocketName server)))))"
]
case final of
Fork (Stem Leaf) (Fork val Leaf) ->
case toNumber val of
Right port | port > 0 -> return ()
Right 0 -> assertFailure "Expected positive port, got 0"
Left _ -> assertFailure $ "Expected numeric port, got: " ++ show val
other -> assertFailure $ "Expected ok result, got: " ++ show other
]
] ]
withFreePort :: (Int -> IO a) -> IO a
withFreePort action =
bracket
(NS.socket NS.AF_INET NS.Stream NS.defaultProtocol)
NS.close
(\s -> do
NS.setSocketOption s NS.ReuseAddr 1
NS.bind s (NS.SockAddrInet 0 (NS.tupleToHostAddress (127, 0, 0, 1)))
port <- NS.socketPort s
action (fromIntegral port))
runIOSourceWith :: IOPermissions -> T -> T -> String -> IO (T, T) runIOSourceWith :: IOPermissions -> T -> T -> String -> IO (T, T)
runIOSourceWith perms readerEnv initialState source = do runIOSourceWith perms readerEnv initialState source = do
ioEnv <- evaluateFile "./lib/io.tri" ioEnv <- evaluateFile "./lib/io.tri"
evalEnv <- evalTricuWithStore Nothing ioEnv (parseTricu source) sockEnv <- evaluateFile "./lib/socket.tri"
let combinedEnv = Map.union sockEnv ioEnv
evalEnv <- evalTricuWithStore Nothing combinedEnv (parseTricu source)
let fullTree = mainResult evalEnv let fullTree = mainResult evalEnv
result <- runIOWith perms readerEnv initialState fullTree result <- runIOWith perms readerEnv initialState fullTree
case result of case result of
@@ -1942,5 +2127,5 @@ runIOSourceWithEnv perms readerEnv source = fmap fst $ runIOSourceWith perms rea
ioOkResult :: T -> T ioOkResult :: T -> T
ioOkResult val = Fork (Stem Leaf) (Fork val Leaf) ioOkResult val = Fork (Stem Leaf) (Fork val Leaf)
ioErrResult :: Integer -> T ioErrResult :: String -> T
ioErrResult code = Fork Leaf (Fork (ofNumber code) Leaf) ioErrResult msg = Fork Leaf (Fork (ofString msg) Leaf)

View File

@@ -52,6 +52,7 @@ executable tricu
, megaparsec , megaparsec
, memory , memory
, mtl , mtl
, network
, servant , servant
, sqlite-simple , sqlite-simple
, stm , stm
@@ -102,6 +103,7 @@ benchmark tricu-bench
, megaparsec , megaparsec
, memory , memory
, mtl , mtl
, network
, sqlite-simple , sqlite-simple
, text , text
, time , time
@@ -148,6 +150,7 @@ test-suite tricu-tests
, megaparsec , megaparsec
, memory , memory
, mtl , mtl
, network
, servant , servant
, sqlite-simple , sqlite-simple
, stm , stm