Compare commits
2 Commits
593aa96193
...
020fa769a9
| Author | SHA1 | Date | |
|---|---|---|---|
| 020fa769a9 | |||
| 2e13583de3 |
@@ -26,7 +26,7 @@
|
||||
--
|
||||
-- File operations return a Result tree (see lib/base.tri):
|
||||
-- ok value -- pair true (pair value t)
|
||||
-- err code -- pair false (pair code t)
|
||||
-- err msg -- pair false (pair msg t)
|
||||
--
|
||||
-- Use onReadFile / onWriteFile for convenient branching.
|
||||
--
|
||||
|
||||
46
demos/interactionTrees/echo-server.tri
Normal file
46
demos/interactionTrees/echo-server.tri
Normal file
@@ -0,0 +1,46 @@
|
||||
!import "../../lib/base.tri" !Local
|
||||
!import "../../lib/io.tri" !Local
|
||||
!import "../../lib/socket.tri" !Local
|
||||
|
||||
-- Preserve the host-driver Result shape on error, run okCase on success.
|
||||
onOk = action okCase :
|
||||
bind action (result :
|
||||
matchResult
|
||||
(err rest : pure result)
|
||||
okCase
|
||||
result)
|
||||
|
||||
-- Convenience: print a string and continue.
|
||||
printLn = s : bind (putStr (append s "\n")) (_ : pure t)
|
||||
|
||||
-- Main accept+echo loop. Recursion via y.
|
||||
echoLoop = y (self server :
|
||||
bind (accept server) (acceptResult :
|
||||
matchResult
|
||||
(err rest :
|
||||
bind (printLn (append "accept error: " err)) (_ :
|
||||
self server))
|
||||
(accepted rest :
|
||||
matchPair
|
||||
(clientSock addr :
|
||||
bind (printLn (append "client from " addr)) (_ :
|
||||
bind (recv clientSock 4096) (msgResult :
|
||||
matchResult
|
||||
(err rest :
|
||||
bind (closeSocket clientSock) (_ :
|
||||
self server))
|
||||
(msg rest :
|
||||
bind (send clientSock msg) (_ :
|
||||
bind (closeSocket clientSock) (_ :
|
||||
self server)))
|
||||
msgResult)))
|
||||
accepted)
|
||||
acceptResult))
|
||||
|
||||
main = io (
|
||||
onOk socket (server rest :
|
||||
onOk (bindSocket server "127.0.0.1" 0) (_ rest :
|
||||
onOk (listen server 5) (_ rest :
|
||||
onOk (getSocketName server) (port rest :
|
||||
bind (printLn (append "Echo server listening on port " (showNumber port))) (_ :
|
||||
echoLoop server))))))
|
||||
@@ -74,7 +74,7 @@ succ = y (self :
|
||||
t))
|
||||
|
||||
ok = value rest : pair true (pair value rest)
|
||||
err = code rest : pair false (pair code rest)
|
||||
err = msg rest : pair false (pair msg rest)
|
||||
|
||||
matchResult = (errCase okCase result :
|
||||
matchPair
|
||||
|
||||
@@ -63,24 +63,24 @@ onWriteFile = (path contents errCase okCase :
|
||||
|
||||
readFileOrPrintError = (path okCase :
|
||||
onReadFile path
|
||||
(err rest : putStrLn "Read failed")
|
||||
(err rest : putStrLn (append "Read failed: " err))
|
||||
okCase)
|
||||
|
||||
writeFileOrPrintError = (path contents okCase :
|
||||
onWriteFile path contents
|
||||
(err rest : putStrLn "Write failed")
|
||||
(err rest : putStrLn (append "Write failed: " err))
|
||||
okCase)
|
||||
|
||||
copyFile = (src dst :
|
||||
bind (readFile src)
|
||||
(result :
|
||||
matchResult
|
||||
(err rest : putStrLn "Read failed")
|
||||
(err rest : putStrLn (append "Read failed: " err))
|
||||
(contents rest :
|
||||
bind (writeFile dst contents)
|
||||
(wr :
|
||||
matchResult
|
||||
(err rest : putStrLn "Write failed")
|
||||
(err rest : putStrLn (append "Write failed: " err))
|
||||
(ok rest : pure t)
|
||||
wr))
|
||||
result))
|
||||
|
||||
63
lib/socket.tri
Normal file
63
lib/socket.tri
Normal file
@@ -0,0 +1,63 @@
|
||||
!import "base.tri" !Local
|
||||
!import "io.tri" !Local
|
||||
|
||||
-- Socket primitives for the IO driver.
|
||||
-- All actions return a Result tree (see lib/base.tri):
|
||||
-- ok value -- pair true (pair value t)
|
||||
-- err msg -- pair false (pair msg t)
|
||||
|
||||
socket = pair 70 t
|
||||
closeSocket = sock : pair 71 sock
|
||||
bindSocket = sock addr port : pair 72 (pair sock (pair addr port))
|
||||
listen = sock backlog : pair 73 (pair sock backlog)
|
||||
accept = sock : pair 74 sock
|
||||
connect = sock addr port : pair 75 (pair sock (pair addr port))
|
||||
recv = sock maxBytes : pair 76 (pair sock maxBytes)
|
||||
send = sock bytes : pair 77 (pair sock bytes)
|
||||
getSocketName = sock : pair 78 sock
|
||||
|
||||
-- ---------------------------------------------------------------------------
|
||||
-- Convenience helpers
|
||||
-- ---------------------------------------------------------------------------
|
||||
|
||||
onSocket = (action errCase okCase :
|
||||
bind action (result :
|
||||
matchResult errCase okCase result))
|
||||
|
||||
-- Create a listening socket bound to an address and port.
|
||||
-- Returns ok listenSocket or err message.
|
||||
listenSocket = addr port backlog :
|
||||
bind (socket) (result :
|
||||
matchResult
|
||||
(err rest : pure (err "socket creation failed"))
|
||||
(sock rest :
|
||||
bind (bindSocket sock addr port) (bindResult :
|
||||
matchResult
|
||||
(err rest : pure (err "bind failed"))
|
||||
(_ rest :
|
||||
bind (listen sock backlog) (listenResult :
|
||||
matchResult
|
||||
(err rest : pure (err "listen failed"))
|
||||
(_ rest : pure (ok sock))
|
||||
listenResult))
|
||||
bindResult))
|
||||
result)
|
||||
|
||||
-- Accept a connection and return (clientSocket, peerAddr).
|
||||
-- The returned peerAddr is a string like "127.0.0.1:8080".
|
||||
onAccept = (sock errCase okCase :
|
||||
bind (accept sock) (result :
|
||||
matchResult errCase okCase result))
|
||||
|
||||
-- Receive all available bytes up to maxBytes.
|
||||
onRecv = (sock maxBytes errCase okCase :
|
||||
bind (recv sock maxBytes) (result :
|
||||
matchResult errCase okCase result))
|
||||
|
||||
-- Send bytes and return number of bytes sent.
|
||||
onSend = (sock bytes errCase okCase :
|
||||
bind (send sock bytes) (result :
|
||||
matchResult errCase okCase result))
|
||||
|
||||
-- Close a socket, ignoring errors.
|
||||
closeSocket_ = sock : bind (closeSocket sock) (_ : pure t)
|
||||
45
src/Eval.hs
45
src/Eval.hs
@@ -33,16 +33,15 @@ type Uses = [Bool]
|
||||
|
||||
evalSingle :: Env -> TricuAST -> Env
|
||||
evalSingle env term
|
||||
| SDef name [] body <- term
|
||||
= case Map.lookup name env of
|
||||
| SDef name params body <- term
|
||||
= let res = evalASTSync env (if null params then body else SLambda params body)
|
||||
in case Map.lookup name env of
|
||||
Just existingValue
|
||||
| existingValue == evalASTSync env body -> env
|
||||
| existingValue == res -> env
|
||||
| otherwise
|
||||
-> let res = evalASTSync env body
|
||||
in Map.insert "!result" res (Map.insert name res env)
|
||||
-> Map.insert "!result" res (Map.insert name res env)
|
||||
Nothing
|
||||
-> let res = evalASTSync env body
|
||||
in Map.insert "!result" res (Map.insert name res env)
|
||||
-> Map.insert "!result" res (Map.insert name res env)
|
||||
| SApp func arg <- term
|
||||
= let res = apply (evalASTSync env func) (evalASTSync env arg)
|
||||
in Map.insert "!result" res env
|
||||
@@ -57,7 +56,7 @@ evalSingle env term
|
||||
in Map.insert "!result" res env
|
||||
|
||||
evalTricu :: Env -> [TricuAST] -> Env
|
||||
evalTricu env x = go env (reorderDefs env x)
|
||||
evalTricu env x = go env (reorderDefs env (map recoverParams x))
|
||||
where
|
||||
go env' [] = env'
|
||||
go env' [def] =
|
||||
@@ -102,11 +101,10 @@ evalASTWithEnv mconn localEnv ast = do
|
||||
let combinedEnv = Map.union localEnv storeEnv
|
||||
return $ evalASTSync combinedEnv ast
|
||||
|
||||
-- | Store-aware version of 'evalSingle'.
|
||||
evalSingleWithStore :: Maybe Connection -> Env -> TricuAST -> IO Env
|
||||
evalSingleWithStore mconn env term
|
||||
| SDef name [] body <- term = do
|
||||
res <- evalASTWithEnv mconn env body
|
||||
| SDef name params body <- term = do
|
||||
res <- evalASTWithEnv mconn env (if null params then body else SLambda params body)
|
||||
case Map.lookup name env of
|
||||
Just existingValue
|
||||
| existingValue == res -> return env
|
||||
@@ -116,11 +114,8 @@ evalSingleWithStore mconn env term
|
||||
res <- evalASTWithEnv mconn env term
|
||||
return $ Map.insert "!result" res env
|
||||
|
||||
-- | Store-aware version of 'evalTricu'. Does not preload the entire
|
||||
-- content store; terms are resolved on demand as variables are
|
||||
-- encountered.
|
||||
evalTricuWithStore :: Maybe Connection -> Env -> [TricuAST] -> IO Env
|
||||
evalTricuWithStore mconn env x = go env (reorderDefs env x)
|
||||
evalTricuWithStore mconn env x = go env (reorderDefs env (map recoverParams x))
|
||||
where
|
||||
go env' [] = return env'
|
||||
go env' [def] = do
|
||||
@@ -130,6 +125,10 @@ evalTricuWithStore mconn env x = go env (reorderDefs env x)
|
||||
updatedEnv <- evalSingleWithStore mconn env' def
|
||||
evalTricuWithStore mconn updatedEnv xs
|
||||
|
||||
recoverParams :: TricuAST -> TricuAST
|
||||
recoverParams (SDef name [] (SLambda params body)) = SDef name params body
|
||||
recoverParams term = term
|
||||
|
||||
collectVarNames :: TricuAST -> [(String, Maybe String)]
|
||||
collectVarNames = go []
|
||||
where
|
||||
@@ -189,6 +188,7 @@ elimLambda = go
|
||||
| isSList term = slistTransform term
|
||||
| otherwise = term
|
||||
|
||||
etaReduction (SLambda [v] (SVar x Nothing)) = v == x
|
||||
etaReduction (SLambda [v] (SApp f (SVar x Nothing))) = v == x && not (usesBinder v f)
|
||||
etaReduction _ = False
|
||||
|
||||
@@ -209,8 +209,9 @@ elimLambda = go
|
||||
application (SApp _ _) = True
|
||||
application _ = False
|
||||
|
||||
etaReduceResult (SLambda [_] (SVar _ Nothing)) = _I
|
||||
etaReduceResult (SLambda [_] (SApp f _)) = f
|
||||
etaReduceResult _ = error "etaReduceResult: expected SLambda [v] (SApp f _)"
|
||||
etaReduceResult _ = error "etaReduceResult: unexpected shape"
|
||||
|
||||
lambdaListResult (SLambda [v] (SList xs)) =
|
||||
SLambda [v] (foldr wrapTLeaf TLeaf xs)
|
||||
@@ -254,12 +255,12 @@ composeBody f g x = SApp (SVar f Nothing) (SApp (SVar g Nothing) (SVar x Nothin
|
||||
isFree :: String -> TricuAST -> Bool
|
||||
isFree x t = Set.member x (freeVars t)
|
||||
|
||||
-- Keep old freeVars for compatibility with reorderDefs which still uses TricuAST
|
||||
freeVars :: TricuAST -> Set String
|
||||
freeVars (SVar v Nothing) = Set.singleton v
|
||||
freeVars (SVar v (Just _)) = Set.singleton v
|
||||
freeVars (SApp t u) = Set.union (freeVars t) (freeVars u)
|
||||
freeVars (SLambda vs body) = Set.difference (freeVars body) (Set.fromList vs)
|
||||
freeVars (SDef _ params body) = Set.difference (freeVars body) (Set.fromList params)
|
||||
freeVars (TStem t) = freeVars t
|
||||
freeVars (TFork t u) = Set.union (freeVars t) (freeVars u)
|
||||
freeVars (SList xs) = foldMap freeVars xs
|
||||
@@ -275,7 +276,7 @@ reorderDefs env defs
|
||||
(defsOnly, others) = partition isDef defs
|
||||
defNames = [ name | SDef name _ _ <- defsOnly ]
|
||||
|
||||
defsWithFreeVars = [(def, freeVars body) | def@(SDef _ _ body) <- defsOnly]
|
||||
defsWithFreeVars = [(def, freeVars def) | def <- defsOnly]
|
||||
|
||||
graph = buildDepGraph defsOnly
|
||||
sortedDefs = sortDeps graph
|
||||
@@ -298,8 +299,8 @@ buildDepGraph topDefs
|
||||
"Conflicting definitions detected: " ++ show conflictingDefs
|
||||
| otherwise =
|
||||
Map.fromList
|
||||
[ (name, depends topDefs (SDef name [] body))
|
||||
| SDef name _ body <- topDefs]
|
||||
[ (name, depends topDefs def)
|
||||
| def@(SDef name _ _) <- topDefs]
|
||||
where
|
||||
defsMap = Map.fromListWith (++)
|
||||
[(name, [(name, body)]) | SDef name _ body <- topDefs]
|
||||
@@ -329,10 +330,10 @@ sortDeps graph = go [] Set.empty (Map.keys graph)
|
||||
notReady
|
||||
|
||||
depends :: [TricuAST] -> TricuAST -> Set.Set String
|
||||
depends topDefs (SDef _ _ body) =
|
||||
depends topDefs def@(SDef _ _ _) =
|
||||
Set.intersection
|
||||
(Set.fromList [n | SDef n _ _ <- topDefs])
|
||||
(freeVars body)
|
||||
(freeVars def)
|
||||
depends _ _ = Set.empty
|
||||
|
||||
result :: Env -> T
|
||||
|
||||
418
src/IODriver.hs
418
src/IODriver.hs
@@ -12,7 +12,7 @@ import Research (T(..), apply, toString, toNumber, ofString, ofNumber, ofBytes,
|
||||
import qualified Data.ByteString as BS
|
||||
import System.IO (putStr, getLine)
|
||||
import qualified System.IO as IO
|
||||
import Control.Exception (try, IOException, SomeException)
|
||||
import Control.Exception (try, catch, IOException, SomeException)
|
||||
import System.IO.Error (isDoesNotExistError, isPermissionError, isAlreadyExistsError)
|
||||
import Data.List (isPrefixOf)
|
||||
import System.FilePath (normalise, isRelative, (</>), addTrailingPathSeparator, splitDirectories)
|
||||
@@ -27,6 +27,8 @@ import Control.Concurrent.STM (TVar, newTVarIO, atomically, readTVar, writeTVar,
|
||||
import qualified Data.Set as Set
|
||||
import Data.Set (Set)
|
||||
import qualified Data.Foldable as Fold
|
||||
import qualified Network.Socket as NS
|
||||
import qualified Network.Socket.ByteString as NSB
|
||||
|
||||
-- ---------------------------------------------------------------------------
|
||||
-- Permissions
|
||||
@@ -95,67 +97,36 @@ data Machine = Machine
|
||||
--
|
||||
-- Runtime protocol errors are returned as direct values via errResult.
|
||||
|
||||
-- Error code ranges:
|
||||
-- 1-19 host IO / filesystem errors
|
||||
-- 20-39 policy / permission errors
|
||||
-- 40-59 protocol / decode / type errors
|
||||
-- 60-79 async errors
|
||||
-- 80-99 scheduler / runtime errors
|
||||
|
||||
-- Host IO / filesystem errors (1-19)
|
||||
errDoesNotExist, errPermission, errAlreadyExists, errIOOther :: Integer
|
||||
errDoesNotExist = 1
|
||||
errPermission = 2
|
||||
errAlreadyExists = 3
|
||||
errIOOther = 4
|
||||
|
||||
-- Policy / permission errors (20-39)
|
||||
errPolicyDeny :: Integer
|
||||
errPolicyDeny = 20
|
||||
|
||||
-- Protocol / decode / type errors (40-59)
|
||||
errInvalidAction, errInvalidString :: Integer
|
||||
errInvalidAction = 40
|
||||
errInvalidString = 41
|
||||
|
||||
-- Async errors (60-79)
|
||||
errInvalidHandle, errSelfAwait, errInvalidSleep, errCyclicAwait :: Integer
|
||||
errInvalidHandle = 60
|
||||
errSelfAwait = 61
|
||||
errInvalidSleep = 62
|
||||
errCyclicAwait = 63
|
||||
|
||||
-- Scheduler / runtime errors (80-99)
|
||||
errDeadlock :: Integer
|
||||
errDeadlock = 80
|
||||
|
||||
ioErrorCode :: IOException -> Integer
|
||||
ioErrorCode e
|
||||
| isDoesNotExistError e = errDoesNotExist
|
||||
| isPermissionError e = errPermission
|
||||
| isAlreadyExistsError e = errAlreadyExists
|
||||
| otherwise = errIOOther
|
||||
|
||||
okResult :: T -> T
|
||||
okResult val = Fork (Stem Leaf) (Fork val Leaf)
|
||||
|
||||
errResult :: Integer -> T
|
||||
errResult code = Fork Leaf (Fork (ofNumber code) Leaf)
|
||||
errResult :: String -> T
|
||||
errResult msg = Fork Leaf (Fork (ofString msg) Leaf)
|
||||
|
||||
pureAction :: T -> T
|
||||
pureAction x = Fork (ofNumber 0) x
|
||||
|
||||
invalidAsyncHandleResult :: T
|
||||
invalidAsyncHandleResult = errResult errInvalidHandle
|
||||
invalidAsyncHandleResult = errResult "invalid task handle"
|
||||
|
||||
invalidSocketHandleResult :: T
|
||||
invalidSocketHandleResult = errResult "invalid socket handle"
|
||||
|
||||
selfAwaitResult :: T
|
||||
selfAwaitResult = errResult errSelfAwait
|
||||
selfAwaitResult = errResult "self await"
|
||||
|
||||
deadlockResult :: T
|
||||
deadlockResult = errResult errDeadlock
|
||||
deadlockResult = errResult "deadlock"
|
||||
|
||||
invalidSleepResult :: T
|
||||
invalidSleepResult = errResult errInvalidSleep
|
||||
invalidSleepResult = errResult "invalid sleep"
|
||||
|
||||
ioErrorString :: IOException -> String
|
||||
ioErrorString e
|
||||
| isDoesNotExistError e = "does not exist"
|
||||
| isPermissionError e = "permission denied"
|
||||
| isAlreadyExistsError e = "already exists"
|
||||
| otherwise = "io error"
|
||||
|
||||
-- ---------------------------------------------------------------------------
|
||||
-- Task identity and handles
|
||||
@@ -179,6 +150,45 @@ decodeTaskHandle tree =
|
||||
_ ->
|
||||
Left "invalid task handle"
|
||||
|
||||
-- ---------------------------------------------------------------------------
|
||||
-- Socket identity and handles
|
||||
-- ---------------------------------------------------------------------------
|
||||
|
||||
newtype SockId = SockId Integer
|
||||
deriving (Eq, Ord, Show)
|
||||
|
||||
sockHandle :: SockId -> T
|
||||
sockHandle (SockId n) =
|
||||
Fork (ofString "sock") (ofNumber n)
|
||||
|
||||
decodeSockHandle :: T -> Either String SockId
|
||||
decodeSockHandle tree =
|
||||
case tree of
|
||||
Fork tag nTree -> do
|
||||
tagString <- toString tag
|
||||
if tagString == "sock"
|
||||
then SockId <$> toNumber nTree
|
||||
else Left "invalid socket handle tag"
|
||||
_ ->
|
||||
Left "invalid socket handle"
|
||||
|
||||
getSocketPort :: NS.Socket -> IO (Maybe Integer)
|
||||
getSocketPort sock = do
|
||||
addr <- NS.getSocketName sock
|
||||
case addr of
|
||||
NS.SockAddrInet p _ -> return (Just (fromIntegral p))
|
||||
NS.SockAddrInet6 p _ _ _ -> return (Just (fromIntegral p))
|
||||
_ -> return Nothing
|
||||
|
||||
-- ---------------------------------------------------------------------------
|
||||
-- Socket registry
|
||||
-- ---------------------------------------------------------------------------
|
||||
|
||||
data SocketRegistry = SocketRegistry
|
||||
{ sockMap :: Map SockId NS.Socket
|
||||
, sockNextId :: Integer
|
||||
}
|
||||
|
||||
-- ---------------------------------------------------------------------------
|
||||
-- Free-monad action AST
|
||||
-- ---------------------------------------------------------------------------
|
||||
@@ -200,6 +210,15 @@ data Action
|
||||
| AAwait T
|
||||
| AYield
|
||||
| ASleep T
|
||||
| ASocket
|
||||
| ACloseSocket T
|
||||
| ABindSocket T T T
|
||||
| AListen T T
|
||||
| AAccept T
|
||||
| AConnect T T T
|
||||
| ARecv T T
|
||||
| ASend T T
|
||||
| AGetSocketName T
|
||||
deriving (Show)
|
||||
|
||||
-- ---------------------------------------------------------------------------
|
||||
@@ -234,6 +253,19 @@ tagAwait = 61
|
||||
tagYield = 62
|
||||
tagSleep = 63
|
||||
|
||||
tagSocket, tagCloseSocket, tagBindSocket, tagListen, tagAccept :: Integer
|
||||
tagSocket = 70
|
||||
tagCloseSocket = 71
|
||||
tagBindSocket = 72
|
||||
tagListen = 73
|
||||
tagAccept = 74
|
||||
|
||||
tagConnect, tagRecv, tagSend, tagGetSocketName :: Integer
|
||||
tagConnect = 75
|
||||
tagRecv = 76
|
||||
tagSend = 77
|
||||
tagGetSocketName = 78
|
||||
|
||||
data Step
|
||||
= Halt Runtime T
|
||||
| Continue Machine
|
||||
@@ -313,6 +345,43 @@ decodeAction tree =
|
||||
Right n | n == tagSleep ->
|
||||
Right (ASleep payload)
|
||||
|
||||
Right n | n == tagSocket ->
|
||||
Right ASocket
|
||||
|
||||
Right n | n == tagCloseSocket ->
|
||||
Right (ACloseSocket payload)
|
||||
|
||||
Right n | n == tagBindSocket ->
|
||||
case payload of
|
||||
Fork sock (Fork addr port) -> Right (ABindSocket sock addr port)
|
||||
_ -> Left "Invalid BindSocket: expected pair sock (pair addr port)"
|
||||
|
||||
Right n | n == tagListen ->
|
||||
case payload of
|
||||
Fork sock backlog -> Right (AListen sock backlog)
|
||||
_ -> Left "Invalid Listen: expected pair sock backlog"
|
||||
|
||||
Right n | n == tagAccept ->
|
||||
Right (AAccept payload)
|
||||
|
||||
Right n | n == tagConnect ->
|
||||
case payload of
|
||||
Fork sock (Fork addr port) -> Right (AConnect sock addr port)
|
||||
_ -> Left "Invalid Connect: expected pair sock (pair addr port)"
|
||||
|
||||
Right n | n == tagRecv ->
|
||||
case payload of
|
||||
Fork sock maxBytes -> Right (ARecv sock maxBytes)
|
||||
_ -> Left "Invalid Recv: expected pair sock maxBytes"
|
||||
|
||||
Right n | n == tagSend ->
|
||||
case payload of
|
||||
Fork sock bytes -> Right (ASend sock bytes)
|
||||
_ -> Left "Invalid Send: expected pair sock bytes"
|
||||
|
||||
Right n | n == tagGetSocketName ->
|
||||
Right (AGetSocketName payload)
|
||||
|
||||
Right n ->
|
||||
Left $ "Unknown IO action tag: " ++ show n
|
||||
|
||||
@@ -346,11 +415,11 @@ finishValue machine value =
|
||||
, machineFrames = rest
|
||||
})
|
||||
|
||||
stepMachine :: Machine -> IO Step
|
||||
stepMachine machine =
|
||||
stepMachine :: TVar SocketRegistry -> Machine -> IO Step
|
||||
stepMachine sockVar machine =
|
||||
case decodeAction (machineCurrent machine) of
|
||||
Right action -> dispatch action
|
||||
Left _ -> finishValue machine (errResult errInvalidAction)
|
||||
Left _ -> finishValue machine (errResult "invalid action")
|
||||
where
|
||||
dispatch action = case action of
|
||||
APure val ->
|
||||
@@ -367,14 +436,14 @@ stepMachine machine =
|
||||
Right s ->
|
||||
pure (AsyncAction (putStr s >> pure Leaf) machine)
|
||||
Left _ ->
|
||||
finishValue machine (errResult errInvalidString)
|
||||
finishValue machine (errResult "invalid string")
|
||||
|
||||
APutBytes bs ->
|
||||
case decodeBytes bs "PutBytes" of
|
||||
Right b ->
|
||||
pure (AsyncAction (BS.putStr b >> pure Leaf) machine)
|
||||
Left _ ->
|
||||
finishValue machine (errResult errInvalidString)
|
||||
finishValue machine (errResult "invalid bytes")
|
||||
|
||||
AGetLine ->
|
||||
pure (AsyncAction (ofString <$> getLine) machine)
|
||||
@@ -386,7 +455,7 @@ stepMachine machine =
|
||||
case mDeny of
|
||||
Just denied -> finishValue machine denied
|
||||
Nothing -> pure (AsyncAction (tryReadFile p) machine)
|
||||
Left _ -> finishValue machine (errResult errInvalidString)
|
||||
Left _ -> finishValue machine (errResult "invalid string")
|
||||
|
||||
AWriteFile path contents ->
|
||||
case decodeString path "WriteFile" of
|
||||
@@ -397,8 +466,8 @@ stepMachine machine =
|
||||
case mDeny of
|
||||
Just denied -> finishValue machine denied
|
||||
Nothing -> pure (AsyncAction (tryWriteFile p c) machine)
|
||||
Left _ -> finishValue machine (errResult errInvalidString)
|
||||
Left _ -> finishValue machine (errResult errInvalidString)
|
||||
Left _ -> finishValue machine (errResult "invalid string")
|
||||
Left _ -> finishValue machine (errResult "invalid string")
|
||||
|
||||
AWriteBytes path contents ->
|
||||
case decodeString path "WriteBytes" of
|
||||
@@ -409,8 +478,8 @@ stepMachine machine =
|
||||
case mDeny of
|
||||
Just denied -> finishValue machine denied
|
||||
Nothing -> pure (AsyncAction (tryWriteFileBytes p c) machine)
|
||||
Left _ -> finishValue machine (errResult errInvalidString)
|
||||
Left _ -> finishValue machine (errResult errInvalidString)
|
||||
Left _ -> finishValue machine (errResult "invalid bytes")
|
||||
Left _ -> finishValue machine (errResult "invalid string")
|
||||
|
||||
AAsk ->
|
||||
finishValue machine (rtEnv (machineRuntime machine))
|
||||
@@ -453,6 +522,207 @@ stepMachine machine =
|
||||
_ ->
|
||||
finishValue machine invalidSleepResult
|
||||
|
||||
ASocket -> do
|
||||
result <- try (NS.socket NS.AF_INET NS.Stream NS.defaultProtocol) :: IO (Either SomeException NS.Socket)
|
||||
case result of
|
||||
Left e ->
|
||||
finishValue machine (errResult ("io error: " ++ show e))
|
||||
Right sock -> do
|
||||
NS.setSocketOption sock NS.ReuseAddr 1
|
||||
sid <- atomically $ do
|
||||
SocketRegistry m next <- readTVar sockVar
|
||||
let sid = SockId next
|
||||
writeTVar sockVar (SocketRegistry (Map.insert sid sock m) (next + 1))
|
||||
return sid
|
||||
finishValue machine (okResult (sockHandle sid))
|
||||
|
||||
ACloseSocket sockTree ->
|
||||
case decodeSockHandle sockTree of
|
||||
Left _ -> finishValue machine invalidSocketHandleResult
|
||||
Right sid -> do
|
||||
mSock <- atomically $ do
|
||||
SocketRegistry m next <- readTVar sockVar
|
||||
case Map.lookup sid m of
|
||||
Nothing -> return Nothing
|
||||
Just sock -> do
|
||||
writeTVar sockVar (SocketRegistry (Map.delete sid m) next)
|
||||
return (Just sock)
|
||||
case mSock of
|
||||
Nothing -> finishValue machine invalidSocketHandleResult
|
||||
Just sock -> do
|
||||
NS.close sock
|
||||
finishValue machine (okResult Leaf)
|
||||
|
||||
ABindSocket sockTree addrTree portTree ->
|
||||
case decodeSockHandle sockTree of
|
||||
Left _ -> finishValue machine invalidSocketHandleResult
|
||||
Right sid ->
|
||||
case decodeString addrTree "BindSocket" of
|
||||
Left _ -> finishValue machine (errResult "invalid address")
|
||||
Right addrStr ->
|
||||
case toNumber portTree of
|
||||
Left _ -> finishValue machine (errResult "invalid port")
|
||||
Right port -> do
|
||||
mSock <- atomically $ do
|
||||
SocketRegistry m _ <- readTVar sockVar
|
||||
return (Map.lookup sid m)
|
||||
case mSock of
|
||||
Nothing -> finishValue machine invalidSocketHandleResult
|
||||
Just sock -> do
|
||||
result <- try (do
|
||||
addrInfo <- NS.getAddrInfo (Just $ NS.defaultHints { NS.addrSocketType = NS.Stream })
|
||||
(Just addrStr)
|
||||
(Just (show port))
|
||||
let serverAddr = head addrInfo
|
||||
NS.bind sock (NS.addrAddress serverAddr)
|
||||
) :: IO (Either SomeException ())
|
||||
case result of
|
||||
Left e ->
|
||||
finishValue machine (errResult ("io error: " ++ show e))
|
||||
Right () ->
|
||||
finishValue machine (okResult Leaf)
|
||||
|
||||
AListen sockTree backlogTree ->
|
||||
case decodeSockHandle sockTree of
|
||||
Left _ -> finishValue machine invalidSocketHandleResult
|
||||
Right sid ->
|
||||
case toNumber backlogTree of
|
||||
Left _ -> finishValue machine (errResult "invalid backlog")
|
||||
Right backlog -> do
|
||||
mSock <- atomically $ do
|
||||
SocketRegistry m _ <- readTVar sockVar
|
||||
return (Map.lookup sid m)
|
||||
case mSock of
|
||||
Nothing -> finishValue machine invalidSocketHandleResult
|
||||
Just sock -> do
|
||||
result <- try (NS.listen sock (fromIntegral backlog)) :: IO (Either SomeException ())
|
||||
case result of
|
||||
Left e ->
|
||||
finishValue machine (errResult ("io error: " ++ show e))
|
||||
Right () ->
|
||||
finishValue machine (okResult Leaf)
|
||||
|
||||
AAccept listenTree ->
|
||||
case decodeSockHandle listenTree of
|
||||
Left _ -> finishValue machine invalidSocketHandleResult
|
||||
Right listenSid ->
|
||||
pure (AsyncAction (do
|
||||
mListenSock <- atomically $ do
|
||||
SocketRegistry m _ <- readTVar sockVar
|
||||
return (Map.lookup listenSid m)
|
||||
case mListenSock of
|
||||
Nothing -> return (errResult "invalid socket handle")
|
||||
Just listenSock -> do
|
||||
result <- try (NS.accept listenSock) :: IO (Either SomeException (NS.Socket, NS.SockAddr))
|
||||
case result of
|
||||
Left e ->
|
||||
return (errResult ("io error: " ++ show e))
|
||||
Right (clientSock, addr) -> do
|
||||
clientSid <- atomically $ do
|
||||
SocketRegistry m next <- readTVar sockVar
|
||||
let sid = SockId next
|
||||
writeTVar sockVar (SocketRegistry (Map.insert sid clientSock m) (next + 1))
|
||||
return sid
|
||||
let addrStr = case addr of
|
||||
NS.SockAddrInet p h ->
|
||||
let (a,b,c,d) = NS.hostAddressToTuple h
|
||||
in show a ++ "." ++ show b ++ "." ++ show c ++ "." ++ show d ++ ":" ++ show p
|
||||
_ -> show addr
|
||||
return (okResult (Fork (sockHandle clientSid) (ofString addrStr)))
|
||||
) machine)
|
||||
|
||||
AConnect sockTree addrTree portTree ->
|
||||
case decodeSockHandle sockTree of
|
||||
Left _ -> finishValue machine invalidSocketHandleResult
|
||||
Right sid ->
|
||||
case decodeString addrTree "Connect" of
|
||||
Left _ -> finishValue machine (errResult "invalid address")
|
||||
Right addrStr ->
|
||||
case toNumber portTree of
|
||||
Left _ -> finishValue machine (errResult "invalid port")
|
||||
Right port -> do
|
||||
mSock <- atomically $ do
|
||||
SocketRegistry m _ <- readTVar sockVar
|
||||
return (Map.lookup sid m)
|
||||
case mSock of
|
||||
Nothing -> finishValue machine invalidSocketHandleResult
|
||||
Just sock ->
|
||||
pure (AsyncAction (do
|
||||
result <- try (do
|
||||
addrInfo <- NS.getAddrInfo (Just $ NS.defaultHints { NS.addrSocketType = NS.Stream })
|
||||
(Just addrStr)
|
||||
(Just (show port))
|
||||
let serverAddr = head addrInfo
|
||||
NS.connect sock (NS.addrAddress serverAddr)
|
||||
) :: IO (Either SomeException ())
|
||||
case result of
|
||||
Left e ->
|
||||
return (errResult ("io error: " ++ show e))
|
||||
Right () ->
|
||||
return (okResult Leaf)
|
||||
) machine)
|
||||
|
||||
ARecv sockTree maxBytesTree ->
|
||||
case decodeSockHandle sockTree of
|
||||
Left _ -> finishValue machine invalidSocketHandleResult
|
||||
Right sid ->
|
||||
case toNumber maxBytesTree of
|
||||
Left _ -> finishValue machine (errResult "invalid maxBytes")
|
||||
Right maxBytes -> do
|
||||
mSock <- atomically $ do
|
||||
SocketRegistry m _ <- readTVar sockVar
|
||||
return (Map.lookup sid m)
|
||||
case mSock of
|
||||
Nothing -> finishValue machine invalidSocketHandleResult
|
||||
Just sock ->
|
||||
pure (AsyncAction (do
|
||||
result <- try (NSB.recv sock (fromIntegral maxBytes)) :: IO (Either SomeException BS.ByteString)
|
||||
case result of
|
||||
Left e ->
|
||||
return (errResult ("io error: " ++ show e))
|
||||
Right bs ->
|
||||
if BS.null bs
|
||||
then return (errResult "connection closed")
|
||||
else return (okResult (ofBytes bs))
|
||||
) machine)
|
||||
|
||||
AGetSocketName sockTree ->
|
||||
case decodeSockHandle sockTree of
|
||||
Left _ -> finishValue machine invalidSocketHandleResult
|
||||
Right sid -> do
|
||||
mSock <- atomically $ do
|
||||
SocketRegistry m _ <- readTVar sockVar
|
||||
return (Map.lookup sid m)
|
||||
case mSock of
|
||||
Nothing -> finishValue machine invalidSocketHandleResult
|
||||
Just sock -> do
|
||||
mPort <- getSocketPort sock
|
||||
case mPort of
|
||||
Just port -> finishValue machine (okResult (ofNumber port))
|
||||
Nothing -> finishValue machine (errResult "io error: could not get socket name")
|
||||
|
||||
ASend sockTree bytesTree ->
|
||||
case decodeSockHandle sockTree of
|
||||
Left _ -> finishValue machine invalidSocketHandleResult
|
||||
Right sid ->
|
||||
case decodeBytes bytesTree "Send" of
|
||||
Left _ -> finishValue machine (errResult "invalid bytes")
|
||||
Right bs -> do
|
||||
mSock <- atomically $ do
|
||||
SocketRegistry m _ <- readTVar sockVar
|
||||
return (Map.lookup sid m)
|
||||
case mSock of
|
||||
Nothing -> finishValue machine invalidSocketHandleResult
|
||||
Just sock ->
|
||||
pure (AsyncAction (do
|
||||
result <- try (NSB.send sock bs) :: IO (Either SomeException Int)
|
||||
case result of
|
||||
Left e ->
|
||||
return (errResult ("io error: " ++ show e))
|
||||
Right sent ->
|
||||
return (okResult (ofNumber (fromIntegral sent)))
|
||||
) machine)
|
||||
|
||||
-- Permission and IO helpers
|
||||
checkReadPerm p =
|
||||
if allowReadAll (rtPerms (machineRuntime machine))
|
||||
@@ -480,7 +750,7 @@ stepMachine machine =
|
||||
then return Nothing
|
||||
else return $ Just policyErrResult
|
||||
|
||||
policyErrResult = errResult errPolicyDeny
|
||||
policyErrResult = errResult "permission denied"
|
||||
|
||||
canonicalizeSafe :: FilePath -> IO (Either String FilePath)
|
||||
canonicalizeSafe p = do
|
||||
@@ -534,19 +804,19 @@ stepMachine machine =
|
||||
result <- try (BS.readFile path) :: IO (Either IOException BS.ByteString)
|
||||
case result of
|
||||
Right content -> return $ okResult (ofBytes content)
|
||||
Left e -> return $ errResult (ioErrorCode e)
|
||||
Left e -> return $ errResult (ioErrorString e)
|
||||
|
||||
tryWriteFile path contents = do
|
||||
result <- try (IO.writeFile path contents) :: IO (Either IOException ())
|
||||
case result of
|
||||
Right () -> return $ okResult Leaf
|
||||
Left e -> return $ errResult (ioErrorCode e)
|
||||
Left e -> return $ errResult (ioErrorString e)
|
||||
|
||||
tryWriteFileBytes path contents = do
|
||||
result <- try (BS.writeFile path contents) :: IO (Either IOException ())
|
||||
case result of
|
||||
Right () -> return $ okResult Leaf
|
||||
Left e -> return $ errResult (ioErrorCode e)
|
||||
Left e -> return $ errResult (ioErrorString e)
|
||||
|
||||
decodeString t ctx =
|
||||
case toString t of
|
||||
@@ -577,6 +847,8 @@ data Scheduler = Scheduler
|
||||
, schedulerSleepQueue :: Map UTCTime (Set TaskId)
|
||||
, schedulerAsyncCompleted :: TVar (Map TaskId T)
|
||||
, schedulerCompleted :: Map TaskId (T, T)
|
||||
, schedulerSockets :: TVar SocketRegistry
|
||||
, schedulerNextSockId :: Integer
|
||||
}
|
||||
|
||||
instance Show Scheduler where
|
||||
@@ -587,10 +859,12 @@ instance Show Scheduler where
|
||||
++ ", schedulerSleepQueue = " ++ show (schedulerSleepQueue s)
|
||||
++ ", schedulerAsyncCompleted = <tvar>"
|
||||
++ ", schedulerCompleted = " ++ show (schedulerCompleted s)
|
||||
++ ", schedulerSockets = <tvar>"
|
||||
++ ", schedulerNextSockId = " ++ show (schedulerNextSockId s)
|
||||
++ " }"
|
||||
|
||||
initialScheduler :: TVar (Map TaskId T) -> Machine -> Scheduler
|
||||
initialScheduler asyncVar mainMachine =
|
||||
initialScheduler :: TVar (Map TaskId T) -> TVar SocketRegistry -> Machine -> Scheduler
|
||||
initialScheduler asyncVar sockVar mainMachine =
|
||||
Scheduler
|
||||
{ schedulerNextTaskId = 1
|
||||
, schedulerRunnable = Seq.singleton (TaskId 0)
|
||||
@@ -599,6 +873,8 @@ initialScheduler asyncVar mainMachine =
|
||||
, schedulerSleepQueue = Map.empty
|
||||
, schedulerAsyncCompleted = asyncVar
|
||||
, schedulerCompleted = Map.empty
|
||||
, schedulerSockets = sockVar
|
||||
, schedulerNextSockId = 0
|
||||
}
|
||||
|
||||
runtimeOfStatus :: TaskStatus -> Maybe Runtime
|
||||
@@ -732,7 +1008,7 @@ handleStep currentId (AwaitRequested targetId machine) scheduler
|
||||
|
||||
Just (BlockedOn nextId _) ->
|
||||
if wouldCycle targetId currentId (schedulerTasks scheduler)
|
||||
then resumeCurrentWith currentId (errResult errCyclicAwait) machine scheduler
|
||||
then resumeCurrentWith currentId (errResult "cyclic await") machine scheduler
|
||||
else block
|
||||
|
||||
Just _ -> block
|
||||
@@ -759,8 +1035,11 @@ handleStep taskId (SleepRequested ms machine) scheduler = do
|
||||
|
||||
handleStep taskId (AsyncAction ioAction machine) scheduler = do
|
||||
_ <- forkIO $ do
|
||||
result <- ioAction
|
||||
atomically $ modifyTVar' (schedulerAsyncCompleted scheduler) (Map.insert taskId result)
|
||||
result <- (Right <$> ioAction) `catch` \(e :: SomeException) -> pure (Left (show e))
|
||||
atomically $ modifyTVar' (schedulerAsyncCompleted scheduler) (Map.insert taskId $
|
||||
case result of
|
||||
Right val -> val
|
||||
Left msg -> errResult msg)
|
||||
pure scheduler
|
||||
{ schedulerTasks = Map.insert taskId (AsyncWaiting machine) (schedulerTasks scheduler)
|
||||
}
|
||||
@@ -818,7 +1097,7 @@ schedulerStep scheduler = do
|
||||
taskId :< restQueue ->
|
||||
case Map.lookup taskId (schedulerTasks scheduler1) of
|
||||
Just (Runnable machine) -> do
|
||||
step <- stepMachine machine
|
||||
step <- stepMachine (schedulerSockets scheduler1) machine
|
||||
handleStep taskId step scheduler1 { schedulerRunnable = restQueue }
|
||||
|
||||
_ ->
|
||||
@@ -843,6 +1122,7 @@ runIOWith perms env initialState action =
|
||||
Left err -> pure (Left err)
|
||||
Right (_, action') -> do
|
||||
asyncVar <- newTVarIO Map.empty
|
||||
sockVar <- newTVarIO (SocketRegistry Map.empty 0)
|
||||
let initialMachine = Machine
|
||||
{ machineRuntime = Runtime
|
||||
{ rtPerms = perms
|
||||
@@ -852,7 +1132,7 @@ runIOWith perms env initialState action =
|
||||
, machineCurrent = action'
|
||||
, machineFrames = []
|
||||
}
|
||||
Right <$> runScheduler (initialScheduler asyncVar initialMachine)
|
||||
Right <$> runScheduler (initialScheduler asyncVar sockVar initialMachine)
|
||||
|
||||
runIOWithEnv :: IOPermissions -> T -> T -> IO (Either String T)
|
||||
runIOWithEnv perms env action = do
|
||||
|
||||
211
test/Spec.hs
211
test/Spec.hs
@@ -10,7 +10,8 @@ import Wire
|
||||
import ContentStore
|
||||
import IODriver (IOPermissions(..), checkIOSentinel, runIO, runIOWithEnv, runIOWith, unsafePerms, defaultPerms)
|
||||
|
||||
import Control.Exception (evaluate, try, SomeException)
|
||||
import Control.Exception (bracket, evaluate, try, SomeException)
|
||||
import qualified Network.Socket as NS
|
||||
import Control.Monad (forM_)
|
||||
import Control.Monad.IO.Class (liftIO)
|
||||
import System.IO.Temp (withSystemTempDirectory)
|
||||
@@ -1553,15 +1554,15 @@ ioDriverTests = testGroup "IO driver tests"
|
||||
-- Malformed action tests
|
||||
, testCase "unknown IO action tag returns err result" $ do
|
||||
final <- runIOSource "main = io (pair 99 t)"
|
||||
final @?= ioErrResult 40
|
||||
final @?= ioErrResult "invalid action"
|
||||
|
||||
, testCase "malformed Bind returns err result" $ do
|
||||
final <- runIOSource "main = io (pair 1 t)"
|
||||
final @?= ioErrResult 40
|
||||
final @?= ioErrResult "invalid action"
|
||||
|
||||
, testCase "malformed ReadFile payload returns err result" $ do
|
||||
final <- runIOSource "main = io (readFile (t t))"
|
||||
final @?= ioErrResult 41
|
||||
final @?= ioErrResult "invalid string"
|
||||
|
||||
-- Permission tests
|
||||
, testCase "allowed read path succeeds" $
|
||||
@@ -1586,7 +1587,7 @@ ioDriverTests = testGroup "IO driver tests"
|
||||
unlines
|
||||
[ "main = io (readFile \"" ++ deniedPath ++ "\")"
|
||||
]
|
||||
result @?= ioErrResult 20
|
||||
result @?= ioErrResult "permission denied"
|
||||
|
||||
, testCase "writeFile denied path returns err result" $
|
||||
withSystemTempDirectory "tricu-io-write-denied" $ \dir -> do
|
||||
@@ -1597,7 +1598,7 @@ ioDriverTests = testGroup "IO driver tests"
|
||||
unlines
|
||||
[ "main = io (writeFile \"" ++ deniedPath ++ "\" \"x\")"
|
||||
]
|
||||
result @?= ioErrResult 20
|
||||
result @?= ioErrResult "permission denied"
|
||||
|
||||
, testCase "path prefix does not allow prefix bypass" $
|
||||
withSystemTempDirectory "tricu-io-prefix" $ \dir -> do
|
||||
@@ -1611,7 +1612,7 @@ ioDriverTests = testGroup "IO driver tests"
|
||||
unlines
|
||||
[ "main = io (readFile \"" ++ bypassPath ++ "\")"
|
||||
]
|
||||
result @?= ioErrResult 20
|
||||
result @?= ioErrResult "permission denied"
|
||||
|
||||
-- Pure test
|
||||
, testCase "pure performs no effects" $ do
|
||||
@@ -1820,14 +1821,14 @@ ioDriverTests = testGroup "IO driver tests"
|
||||
unlines
|
||||
[ "main = io (await (pair \"task\" 0))"
|
||||
]
|
||||
final @?= ioErrResult 61
|
||||
final @?= ioErrResult "self await"
|
||||
|
||||
, testCase "await invalid handle returns async error" $ do
|
||||
(final, _) <- runIOSourceWith unsafePerms Leaf Leaf $
|
||||
unlines
|
||||
[ "main = io (await 123)"
|
||||
]
|
||||
final @?= ioErrResult 60
|
||||
final @?= ioErrResult "invalid task handle"
|
||||
|
||||
, testCase "yield returns unit and resumes continuation" $ do
|
||||
(final, _) <- runIOSourceWith unsafePerms Leaf Leaf $
|
||||
@@ -1890,7 +1891,7 @@ ioDriverTests = testGroup "IO driver tests"
|
||||
[ "main = io (bind (fork (await (pair \"task\" 0))) (h :"
|
||||
, " await h))"
|
||||
]
|
||||
final @?= ioErrResult 63
|
||||
final @?= ioErrResult "cyclic await"
|
||||
|
||||
, testCase "writeBytes and readFile roundtrip binary data" $
|
||||
withSystemTempDirectory "tricu-io-bytes" $ \dir -> do
|
||||
@@ -1918,12 +1919,196 @@ ioDriverTests = testGroup "IO driver tests"
|
||||
build k = "bind (fork (pure \"x\")) (h : bind (await h) (_ : " ++ build (k - 1) ++ "))"
|
||||
(final, _) <- runIOSourceWith unsafePerms Leaf Leaf ("main = io (" ++ build n ++ ")")
|
||||
final @?= ofString "done"
|
||||
|
||||
, testGroup "Socket primitives"
|
||||
[ testCase "socket returns ok result with valid handle" $ do
|
||||
final <- runIOSource "main = io socket"
|
||||
final @?= ioOkResult (Fork (ofString "sock") (ofNumber 0))
|
||||
|
||||
, testCase "closeSocket on invalid handle returns error" $ do
|
||||
final <- runIOSource "main = io (closeSocket (pair \"sock\" 99999))"
|
||||
final @?= ioErrResult "invalid socket handle"
|
||||
|
||||
, testCase "bindSocket and listen succeed on loopback port 0" $ do
|
||||
final <- runIOSource "main = io (bind socket (result : matchResult (err rest : pure result) (sock rest : bind (bindSocket sock \"127.0.0.1\" 0) (bindResult : matchResult (err rest : pure bindResult) (_ rest : bind (listen sock 1) (listenResult : pure listenResult)) bindResult)) result))"
|
||||
final @?= ioOkResult Leaf
|
||||
|
||||
, testCase "connect to non-listening port returns error" $ do
|
||||
final <- runIOSource $
|
||||
unlines
|
||||
[ "main = io (bind socket (result :"
|
||||
, " matchResult"
|
||||
, " (err rest : pure \"socket-err\")"
|
||||
, " (sock rest : connect sock \"127.0.0.1\" 1)"
|
||||
, " result))"
|
||||
]
|
||||
case final of
|
||||
Fork Leaf (Fork _ Leaf) -> return ()
|
||||
other -> assertFailure $ "Expected error result, got: " ++ show other
|
||||
|
||||
, testCase "accept and recv receive bytes from forked client" $
|
||||
withFreePort $ \port -> do
|
||||
final <- runIOSource $
|
||||
unlines
|
||||
[ "preserveResult = (result okCase :"
|
||||
, " matchResult"
|
||||
, " (err rest : pure result)"
|
||||
, " okCase"
|
||||
, " result)"
|
||||
, ""
|
||||
, "client = port :"
|
||||
, " bind socket (result :"
|
||||
, " preserveResult result (sock rest :"
|
||||
, " bind (connect sock \"127.0.0.1\" port) (connectResult :"
|
||||
, " preserveResult connectResult (_ rest :"
|
||||
, " send sock [104 105]))))"
|
||||
, ""
|
||||
, "main = io ("
|
||||
, " bind socket (result :"
|
||||
, " preserveResult result (server rest :"
|
||||
, " bind (bindSocket server \"127.0.0.1\" " ++ show port ++ ") (bindResult :"
|
||||
, " preserveResult bindResult (_ rest :"
|
||||
, " bind (listen server 1) (listenResult :"
|
||||
, " preserveResult listenResult (_ rest :"
|
||||
, " bind (fork (client " ++ show port ++ ")) (_ :"
|
||||
, " bind (accept server) (acceptResult :"
|
||||
, " preserveResult acceptResult (accepted rest :"
|
||||
, " matchPair"
|
||||
, " (clientSock addr : recv clientSock 2)"
|
||||
, " accepted))))))))))"
|
||||
]
|
||||
final @?= ioOkResult (ofBytes (BS.pack [104, 105]))
|
||||
|
||||
, testCase "client recv receives server response via accepted socket" $
|
||||
withFreePort $ \port -> do
|
||||
final <- runIOSource $
|
||||
unlines
|
||||
[ "preserveResult = (result okCase :"
|
||||
, " matchResult"
|
||||
, " (err rest : pure result)"
|
||||
, " okCase"
|
||||
, " result)"
|
||||
, ""
|
||||
, "serverTask = server :"
|
||||
, " bind (accept server) (acceptResult :"
|
||||
, " preserveResult acceptResult (accepted rest :"
|
||||
, " matchPair"
|
||||
, " (clientSock addr :"
|
||||
, " bind (recv clientSock 4) (msgResult :"
|
||||
, " preserveResult msgResult (_ rest :"
|
||||
, " send clientSock [112 111 110 103])))"
|
||||
, " accepted))"
|
||||
, ""
|
||||
, "clientTask = port :"
|
||||
, " bind socket (result :"
|
||||
, " preserveResult result (sock rest :"
|
||||
, " bind (connect sock \"127.0.0.1\" port) (connectResult :"
|
||||
, " preserveResult connectResult (_ rest :"
|
||||
, " bind (send sock [112 105 110 103]) (_ :"
|
||||
, " recv sock 4)))))"
|
||||
, ""
|
||||
, "main = io ("
|
||||
, " bind socket (result :"
|
||||
, " preserveResult result (server rest :"
|
||||
, " bind (bindSocket server \"127.0.0.1\" " ++ show port ++ ") (bindResult :"
|
||||
, " preserveResult bindResult (_ rest :"
|
||||
, " bind (listen server 1) (listenResult :"
|
||||
, " preserveResult listenResult (_ rest :"
|
||||
, " bind (fork (serverTask server)) (_ :"
|
||||
, " clientTask " ++ show port ++ "))))))))"
|
||||
]
|
||||
final @?= ioOkResult (ofBytes (BS.pack [112, 111, 110, 103]))
|
||||
|
||||
, testCase "recv on closed peer returns connection closed" $
|
||||
withFreePort $ \port -> do
|
||||
final <- runIOSource $
|
||||
unlines
|
||||
[ "preserveResult = (result okCase :"
|
||||
, " matchResult"
|
||||
, " (err rest : pure result)"
|
||||
, " okCase"
|
||||
, " result)"
|
||||
, ""
|
||||
, "clientTask = port :"
|
||||
, " bind socket (result :"
|
||||
, " preserveResult result (sock rest :"
|
||||
, " bind (connect sock \"127.0.0.1\" port) (connectResult :"
|
||||
, " preserveResult connectResult (_ rest :"
|
||||
, " closeSocket sock))))"
|
||||
, ""
|
||||
, "main = io ("
|
||||
, " bind socket (result :"
|
||||
, " preserveResult result (server rest :"
|
||||
, " bind (bindSocket server \"127.0.0.1\" " ++ show port ++ ") (bindResult :"
|
||||
, " preserveResult bindResult (_ rest :"
|
||||
, " bind (listen server 1) (listenResult :"
|
||||
, " preserveResult listenResult (_ rest :"
|
||||
, " bind (fork (clientTask " ++ show port ++ ")) (_ :"
|
||||
, " bind (accept server) (acceptResult :"
|
||||
, " preserveResult acceptResult (accepted rest :"
|
||||
, " matchPair"
|
||||
, " (clientSock addr :"
|
||||
, " bind (yield) (_ :"
|
||||
, " recv clientSock 1))"
|
||||
, " accepted))))))))))"
|
||||
]
|
||||
final @?= ioErrResult "connection closed"
|
||||
|
||||
, testCase "accept invalid socket handle returns error" $ do
|
||||
final <- runIOSource "main = io (accept (pair \"sock\" 99999))"
|
||||
final @?= ioErrResult "invalid socket handle"
|
||||
|
||||
, testCase "recv invalid socket handle returns error" $ do
|
||||
final <- runIOSource "main = io (recv (pair \"sock\" 99999) 1)"
|
||||
final @?= ioErrResult "invalid socket handle"
|
||||
|
||||
, testCase "send invalid socket handle returns error" $ do
|
||||
final <- runIOSource "main = io (send (pair \"sock\" 99999) [(1)])"
|
||||
final @?= ioErrResult "invalid socket handle"
|
||||
|
||||
, testCase "getSocketName returns positive port after bind 0" $ do
|
||||
final <- runIOSource $
|
||||
unlines
|
||||
[ "preserveResult = (result okCase :"
|
||||
, " matchResult"
|
||||
, " (err rest : pure result)"
|
||||
, " okCase"
|
||||
, " result)"
|
||||
, ""
|
||||
, "main = io ("
|
||||
, " bind socket (result :"
|
||||
, " preserveResult result (server rest :"
|
||||
, " bind (bindSocket server \"127.0.0.1\" 0) (bindResult :"
|
||||
, " preserveResult bindResult (_ rest :"
|
||||
, " getSocketName server)))))"
|
||||
]
|
||||
case final of
|
||||
Fork (Stem Leaf) (Fork val Leaf) ->
|
||||
case toNumber val of
|
||||
Right port | port > 0 -> return ()
|
||||
Right 0 -> assertFailure "Expected positive port, got 0"
|
||||
Left _ -> assertFailure $ "Expected numeric port, got: " ++ show val
|
||||
other -> assertFailure $ "Expected ok result, got: " ++ show other
|
||||
]
|
||||
]
|
||||
|
||||
withFreePort :: (Int -> IO a) -> IO a
|
||||
withFreePort action =
|
||||
bracket
|
||||
(NS.socket NS.AF_INET NS.Stream NS.defaultProtocol)
|
||||
NS.close
|
||||
(\s -> do
|
||||
NS.setSocketOption s NS.ReuseAddr 1
|
||||
NS.bind s (NS.SockAddrInet 0 (NS.tupleToHostAddress (127, 0, 0, 1)))
|
||||
port <- NS.socketPort s
|
||||
action (fromIntegral port))
|
||||
|
||||
runIOSourceWith :: IOPermissions -> T -> T -> String -> IO (T, T)
|
||||
runIOSourceWith perms readerEnv initialState source = do
|
||||
ioEnv <- evaluateFile "./lib/io.tri"
|
||||
evalEnv <- evalTricuWithStore Nothing ioEnv (parseTricu source)
|
||||
sockEnv <- evaluateFile "./lib/socket.tri"
|
||||
let combinedEnv = Map.union sockEnv ioEnv
|
||||
evalEnv <- evalTricuWithStore Nothing combinedEnv (parseTricu source)
|
||||
let fullTree = mainResult evalEnv
|
||||
result <- runIOWith perms readerEnv initialState fullTree
|
||||
case result of
|
||||
@@ -1942,5 +2127,5 @@ runIOSourceWithEnv perms readerEnv source = fmap fst $ runIOSourceWith perms rea
|
||||
ioOkResult :: T -> T
|
||||
ioOkResult val = Fork (Stem Leaf) (Fork val Leaf)
|
||||
|
||||
ioErrResult :: Integer -> T
|
||||
ioErrResult code = Fork Leaf (Fork (ofNumber code) Leaf)
|
||||
ioErrResult :: String -> T
|
||||
ioErrResult msg = Fork Leaf (Fork (ofString msg) Leaf)
|
||||
|
||||
@@ -52,6 +52,7 @@ executable tricu
|
||||
, megaparsec
|
||||
, memory
|
||||
, mtl
|
||||
, network
|
||||
, servant
|
||||
, sqlite-simple
|
||||
, stm
|
||||
@@ -102,6 +103,7 @@ benchmark tricu-bench
|
||||
, megaparsec
|
||||
, memory
|
||||
, mtl
|
||||
, network
|
||||
, sqlite-simple
|
||||
, text
|
||||
, time
|
||||
@@ -148,6 +150,7 @@ test-suite tricu-tests
|
||||
, megaparsec
|
||||
, memory
|
||||
, mtl
|
||||
, network
|
||||
, servant
|
||||
, sqlite-simple
|
||||
, stm
|
||||
|
||||
Reference in New Issue
Block a user