2 Commits

Author SHA1 Message Date
020fa769a9 Event loop! 2026-05-19 17:00:36 -05:00
2e13583de3 Strings for IO driver errors 2026-05-18 19:12:42 -05:00
9 changed files with 691 additions and 113 deletions

View File

@@ -26,7 +26,7 @@
--
-- File operations return a Result tree (see lib/base.tri):
-- ok value -- pair true (pair value t)
-- err code -- pair false (pair code t)
-- err msg -- pair false (pair msg t)
--
-- Use onReadFile / onWriteFile for convenient branching.
--

View File

@@ -0,0 +1,46 @@
!import "../../lib/base.tri" !Local
!import "../../lib/io.tri" !Local
!import "../../lib/socket.tri" !Local
-- Preserve the host-driver Result shape on error, run okCase on success.
onOk = action okCase :
bind action (result :
matchResult
(err rest : pure result)
okCase
result)
-- Convenience: print a string and continue.
printLn = s : bind (putStr (append s "\n")) (_ : pure t)
-- Main accept+echo loop. Recursion via y.
echoLoop = y (self server :
bind (accept server) (acceptResult :
matchResult
(err rest :
bind (printLn (append "accept error: " err)) (_ :
self server))
(accepted rest :
matchPair
(clientSock addr :
bind (printLn (append "client from " addr)) (_ :
bind (recv clientSock 4096) (msgResult :
matchResult
(err rest :
bind (closeSocket clientSock) (_ :
self server))
(msg rest :
bind (send clientSock msg) (_ :
bind (closeSocket clientSock) (_ :
self server)))
msgResult)))
accepted)
acceptResult))
main = io (
onOk socket (server rest :
onOk (bindSocket server "127.0.0.1" 0) (_ rest :
onOk (listen server 5) (_ rest :
onOk (getSocketName server) (port rest :
bind (printLn (append "Echo server listening on port " (showNumber port))) (_ :
echoLoop server))))))

View File

@@ -74,7 +74,7 @@ succ = y (self :
t))
ok = value rest : pair true (pair value rest)
err = code rest : pair false (pair code rest)
err = msg rest : pair false (pair msg rest)
matchResult = (errCase okCase result :
matchPair

View File

@@ -63,24 +63,24 @@ onWriteFile = (path contents errCase okCase :
readFileOrPrintError = (path okCase :
onReadFile path
(err rest : putStrLn "Read failed")
(err rest : putStrLn (append "Read failed: " err))
okCase)
writeFileOrPrintError = (path contents okCase :
onWriteFile path contents
(err rest : putStrLn "Write failed")
(err rest : putStrLn (append "Write failed: " err))
okCase)
copyFile = (src dst :
bind (readFile src)
(result :
matchResult
(err rest : putStrLn "Read failed")
(err rest : putStrLn (append "Read failed: " err))
(contents rest :
bind (writeFile dst contents)
(wr :
matchResult
(err rest : putStrLn "Write failed")
(err rest : putStrLn (append "Write failed: " err))
(ok rest : pure t)
wr))
result))

63
lib/socket.tri Normal file
View File

@@ -0,0 +1,63 @@
!import "base.tri" !Local
!import "io.tri" !Local
-- Socket primitives for the IO driver.
-- All actions return a Result tree (see lib/base.tri):
-- ok value -- pair true (pair value t)
-- err msg -- pair false (pair msg t)
socket = pair 70 t
closeSocket = sock : pair 71 sock
bindSocket = sock addr port : pair 72 (pair sock (pair addr port))
listen = sock backlog : pair 73 (pair sock backlog)
accept = sock : pair 74 sock
connect = sock addr port : pair 75 (pair sock (pair addr port))
recv = sock maxBytes : pair 76 (pair sock maxBytes)
send = sock bytes : pair 77 (pair sock bytes)
getSocketName = sock : pair 78 sock
-- ---------------------------------------------------------------------------
-- Convenience helpers
-- ---------------------------------------------------------------------------
onSocket = (action errCase okCase :
bind action (result :
matchResult errCase okCase result))
-- Create a listening socket bound to an address and port.
-- Returns ok listenSocket or err message.
listenSocket = addr port backlog :
bind (socket) (result :
matchResult
(err rest : pure (err "socket creation failed"))
(sock rest :
bind (bindSocket sock addr port) (bindResult :
matchResult
(err rest : pure (err "bind failed"))
(_ rest :
bind (listen sock backlog) (listenResult :
matchResult
(err rest : pure (err "listen failed"))
(_ rest : pure (ok sock))
listenResult))
bindResult))
result)
-- Accept a connection and return (clientSocket, peerAddr).
-- The returned peerAddr is a string like "127.0.0.1:8080".
onAccept = (sock errCase okCase :
bind (accept sock) (result :
matchResult errCase okCase result))
-- Receive all available bytes up to maxBytes.
onRecv = (sock maxBytes errCase okCase :
bind (recv sock maxBytes) (result :
matchResult errCase okCase result))
-- Send bytes and return number of bytes sent.
onSend = (sock bytes errCase okCase :
bind (send sock bytes) (result :
matchResult errCase okCase result))
-- Close a socket, ignoring errors.
closeSocket_ = sock : bind (closeSocket sock) (_ : pure t)

View File

@@ -33,16 +33,15 @@ type Uses = [Bool]
evalSingle :: Env -> TricuAST -> Env
evalSingle env term
| SDef name [] body <- term
= case Map.lookup name env of
Just existingValue
| existingValue == evalASTSync env body -> env
| otherwise
-> let res = evalASTSync env body
in Map.insert "!result" res (Map.insert name res env)
Nothing
-> let res = evalASTSync env body
in Map.insert "!result" res (Map.insert name res env)
| SDef name params body <- term
= let res = evalASTSync env (if null params then body else SLambda params body)
in case Map.lookup name env of
Just existingValue
| existingValue == res -> env
| otherwise
-> Map.insert "!result" res (Map.insert name res env)
Nothing
-> Map.insert "!result" res (Map.insert name res env)
| SApp func arg <- term
= let res = apply (evalASTSync env func) (evalASTSync env arg)
in Map.insert "!result" res env
@@ -57,7 +56,7 @@ evalSingle env term
in Map.insert "!result" res env
evalTricu :: Env -> [TricuAST] -> Env
evalTricu env x = go env (reorderDefs env x)
evalTricu env x = go env (reorderDefs env (map recoverParams x))
where
go env' [] = env'
go env' [def] =
@@ -102,11 +101,10 @@ evalASTWithEnv mconn localEnv ast = do
let combinedEnv = Map.union localEnv storeEnv
return $ evalASTSync combinedEnv ast
-- | Store-aware version of 'evalSingle'.
evalSingleWithStore :: Maybe Connection -> Env -> TricuAST -> IO Env
evalSingleWithStore mconn env term
| SDef name [] body <- term = do
res <- evalASTWithEnv mconn env body
| SDef name params body <- term = do
res <- evalASTWithEnv mconn env (if null params then body else SLambda params body)
case Map.lookup name env of
Just existingValue
| existingValue == res -> return env
@@ -116,11 +114,8 @@ evalSingleWithStore mconn env term
res <- evalASTWithEnv mconn env term
return $ Map.insert "!result" res env
-- | Store-aware version of 'evalTricu'. Does not preload the entire
-- content store; terms are resolved on demand as variables are
-- encountered.
evalTricuWithStore :: Maybe Connection -> Env -> [TricuAST] -> IO Env
evalTricuWithStore mconn env x = go env (reorderDefs env x)
evalTricuWithStore mconn env x = go env (reorderDefs env (map recoverParams x))
where
go env' [] = return env'
go env' [def] = do
@@ -130,6 +125,10 @@ evalTricuWithStore mconn env x = go env (reorderDefs env x)
updatedEnv <- evalSingleWithStore mconn env' def
evalTricuWithStore mconn updatedEnv xs
recoverParams :: TricuAST -> TricuAST
recoverParams (SDef name [] (SLambda params body)) = SDef name params body
recoverParams term = term
collectVarNames :: TricuAST -> [(String, Maybe String)]
collectVarNames = go []
where
@@ -189,6 +188,7 @@ elimLambda = go
| isSList term = slistTransform term
| otherwise = term
etaReduction (SLambda [v] (SVar x Nothing)) = v == x
etaReduction (SLambda [v] (SApp f (SVar x Nothing))) = v == x && not (usesBinder v f)
etaReduction _ = False
@@ -209,8 +209,9 @@ elimLambda = go
application (SApp _ _) = True
application _ = False
etaReduceResult (SLambda [_] (SVar _ Nothing)) = _I
etaReduceResult (SLambda [_] (SApp f _)) = f
etaReduceResult _ = error "etaReduceResult: expected SLambda [v] (SApp f _)"
etaReduceResult _ = error "etaReduceResult: unexpected shape"
lambdaListResult (SLambda [v] (SList xs)) =
SLambda [v] (foldr wrapTLeaf TLeaf xs)
@@ -254,12 +255,12 @@ composeBody f g x = SApp (SVar f Nothing) (SApp (SVar g Nothing) (SVar x Nothin
isFree :: String -> TricuAST -> Bool
isFree x t = Set.member x (freeVars t)
-- Keep old freeVars for compatibility with reorderDefs which still uses TricuAST
freeVars :: TricuAST -> Set String
freeVars (SVar v Nothing) = Set.singleton v
freeVars (SVar v (Just _)) = Set.singleton v
freeVars (SApp t u) = Set.union (freeVars t) (freeVars u)
freeVars (SLambda vs body) = Set.difference (freeVars body) (Set.fromList vs)
freeVars (SDef _ params body) = Set.difference (freeVars body) (Set.fromList params)
freeVars (TStem t) = freeVars t
freeVars (TFork t u) = Set.union (freeVars t) (freeVars u)
freeVars (SList xs) = foldMap freeVars xs
@@ -275,7 +276,7 @@ reorderDefs env defs
(defsOnly, others) = partition isDef defs
defNames = [ name | SDef name _ _ <- defsOnly ]
defsWithFreeVars = [(def, freeVars body) | def@(SDef _ _ body) <- defsOnly]
defsWithFreeVars = [(def, freeVars def) | def <- defsOnly]
graph = buildDepGraph defsOnly
sortedDefs = sortDeps graph
@@ -298,8 +299,8 @@ buildDepGraph topDefs
"Conflicting definitions detected: " ++ show conflictingDefs
| otherwise =
Map.fromList
[ (name, depends topDefs (SDef name [] body))
| SDef name _ body <- topDefs]
[ (name, depends topDefs def)
| def@(SDef name _ _) <- topDefs]
where
defsMap = Map.fromListWith (++)
[(name, [(name, body)]) | SDef name _ body <- topDefs]
@@ -329,10 +330,10 @@ sortDeps graph = go [] Set.empty (Map.keys graph)
notReady
depends :: [TricuAST] -> TricuAST -> Set.Set String
depends topDefs (SDef _ _ body) =
depends topDefs def@(SDef _ _ _) =
Set.intersection
(Set.fromList [n | SDef n _ _ <- topDefs])
(freeVars body)
(freeVars def)
depends _ _ = Set.empty
result :: Env -> T

View File

@@ -12,7 +12,7 @@ import Research (T(..), apply, toString, toNumber, ofString, ofNumber, ofBytes,
import qualified Data.ByteString as BS
import System.IO (putStr, getLine)
import qualified System.IO as IO
import Control.Exception (try, IOException, SomeException)
import Control.Exception (try, catch, IOException, SomeException)
import System.IO.Error (isDoesNotExistError, isPermissionError, isAlreadyExistsError)
import Data.List (isPrefixOf)
import System.FilePath (normalise, isRelative, (</>), addTrailingPathSeparator, splitDirectories)
@@ -27,6 +27,8 @@ import Control.Concurrent.STM (TVar, newTVarIO, atomically, readTVar, writeTVar,
import qualified Data.Set as Set
import Data.Set (Set)
import qualified Data.Foldable as Fold
import qualified Network.Socket as NS
import qualified Network.Socket.ByteString as NSB
-- ---------------------------------------------------------------------------
-- Permissions
@@ -95,67 +97,36 @@ data Machine = Machine
--
-- Runtime protocol errors are returned as direct values via errResult.
-- Error code ranges:
-- 1-19 host IO / filesystem errors
-- 20-39 policy / permission errors
-- 40-59 protocol / decode / type errors
-- 60-79 async errors
-- 80-99 scheduler / runtime errors
-- Host IO / filesystem errors (1-19)
errDoesNotExist, errPermission, errAlreadyExists, errIOOther :: Integer
errDoesNotExist = 1
errPermission = 2
errAlreadyExists = 3
errIOOther = 4
-- Policy / permission errors (20-39)
errPolicyDeny :: Integer
errPolicyDeny = 20
-- Protocol / decode / type errors (40-59)
errInvalidAction, errInvalidString :: Integer
errInvalidAction = 40
errInvalidString = 41
-- Async errors (60-79)
errInvalidHandle, errSelfAwait, errInvalidSleep, errCyclicAwait :: Integer
errInvalidHandle = 60
errSelfAwait = 61
errInvalidSleep = 62
errCyclicAwait = 63
-- Scheduler / runtime errors (80-99)
errDeadlock :: Integer
errDeadlock = 80
ioErrorCode :: IOException -> Integer
ioErrorCode e
| isDoesNotExistError e = errDoesNotExist
| isPermissionError e = errPermission
| isAlreadyExistsError e = errAlreadyExists
| otherwise = errIOOther
okResult :: T -> T
okResult val = Fork (Stem Leaf) (Fork val Leaf)
errResult :: Integer -> T
errResult code = Fork Leaf (Fork (ofNumber code) Leaf)
errResult :: String -> T
errResult msg = Fork Leaf (Fork (ofString msg) Leaf)
pureAction :: T -> T
pureAction x = Fork (ofNumber 0) x
invalidAsyncHandleResult :: T
invalidAsyncHandleResult = errResult errInvalidHandle
invalidAsyncHandleResult = errResult "invalid task handle"
invalidSocketHandleResult :: T
invalidSocketHandleResult = errResult "invalid socket handle"
selfAwaitResult :: T
selfAwaitResult = errResult errSelfAwait
selfAwaitResult = errResult "self await"
deadlockResult :: T
deadlockResult = errResult errDeadlock
deadlockResult = errResult "deadlock"
invalidSleepResult :: T
invalidSleepResult = errResult errInvalidSleep
invalidSleepResult = errResult "invalid sleep"
ioErrorString :: IOException -> String
ioErrorString e
| isDoesNotExistError e = "does not exist"
| isPermissionError e = "permission denied"
| isAlreadyExistsError e = "already exists"
| otherwise = "io error"
-- ---------------------------------------------------------------------------
-- Task identity and handles
@@ -179,6 +150,45 @@ decodeTaskHandle tree =
_ ->
Left "invalid task handle"
-- ---------------------------------------------------------------------------
-- Socket identity and handles
-- ---------------------------------------------------------------------------
newtype SockId = SockId Integer
deriving (Eq, Ord, Show)
sockHandle :: SockId -> T
sockHandle (SockId n) =
Fork (ofString "sock") (ofNumber n)
decodeSockHandle :: T -> Either String SockId
decodeSockHandle tree =
case tree of
Fork tag nTree -> do
tagString <- toString tag
if tagString == "sock"
then SockId <$> toNumber nTree
else Left "invalid socket handle tag"
_ ->
Left "invalid socket handle"
getSocketPort :: NS.Socket -> IO (Maybe Integer)
getSocketPort sock = do
addr <- NS.getSocketName sock
case addr of
NS.SockAddrInet p _ -> return (Just (fromIntegral p))
NS.SockAddrInet6 p _ _ _ -> return (Just (fromIntegral p))
_ -> return Nothing
-- ---------------------------------------------------------------------------
-- Socket registry
-- ---------------------------------------------------------------------------
data SocketRegistry = SocketRegistry
{ sockMap :: Map SockId NS.Socket
, sockNextId :: Integer
}
-- ---------------------------------------------------------------------------
-- Free-monad action AST
-- ---------------------------------------------------------------------------
@@ -200,6 +210,15 @@ data Action
| AAwait T
| AYield
| ASleep T
| ASocket
| ACloseSocket T
| ABindSocket T T T
| AListen T T
| AAccept T
| AConnect T T T
| ARecv T T
| ASend T T
| AGetSocketName T
deriving (Show)
-- ---------------------------------------------------------------------------
@@ -234,6 +253,19 @@ tagAwait = 61
tagYield = 62
tagSleep = 63
tagSocket, tagCloseSocket, tagBindSocket, tagListen, tagAccept :: Integer
tagSocket = 70
tagCloseSocket = 71
tagBindSocket = 72
tagListen = 73
tagAccept = 74
tagConnect, tagRecv, tagSend, tagGetSocketName :: Integer
tagConnect = 75
tagRecv = 76
tagSend = 77
tagGetSocketName = 78
data Step
= Halt Runtime T
| Continue Machine
@@ -313,6 +345,43 @@ decodeAction tree =
Right n | n == tagSleep ->
Right (ASleep payload)
Right n | n == tagSocket ->
Right ASocket
Right n | n == tagCloseSocket ->
Right (ACloseSocket payload)
Right n | n == tagBindSocket ->
case payload of
Fork sock (Fork addr port) -> Right (ABindSocket sock addr port)
_ -> Left "Invalid BindSocket: expected pair sock (pair addr port)"
Right n | n == tagListen ->
case payload of
Fork sock backlog -> Right (AListen sock backlog)
_ -> Left "Invalid Listen: expected pair sock backlog"
Right n | n == tagAccept ->
Right (AAccept payload)
Right n | n == tagConnect ->
case payload of
Fork sock (Fork addr port) -> Right (AConnect sock addr port)
_ -> Left "Invalid Connect: expected pair sock (pair addr port)"
Right n | n == tagRecv ->
case payload of
Fork sock maxBytes -> Right (ARecv sock maxBytes)
_ -> Left "Invalid Recv: expected pair sock maxBytes"
Right n | n == tagSend ->
case payload of
Fork sock bytes -> Right (ASend sock bytes)
_ -> Left "Invalid Send: expected pair sock bytes"
Right n | n == tagGetSocketName ->
Right (AGetSocketName payload)
Right n ->
Left $ "Unknown IO action tag: " ++ show n
@@ -346,11 +415,11 @@ finishValue machine value =
, machineFrames = rest
})
stepMachine :: Machine -> IO Step
stepMachine machine =
stepMachine :: TVar SocketRegistry -> Machine -> IO Step
stepMachine sockVar machine =
case decodeAction (machineCurrent machine) of
Right action -> dispatch action
Left _ -> finishValue machine (errResult errInvalidAction)
Left _ -> finishValue machine (errResult "invalid action")
where
dispatch action = case action of
APure val ->
@@ -367,14 +436,14 @@ stepMachine machine =
Right s ->
pure (AsyncAction (putStr s >> pure Leaf) machine)
Left _ ->
finishValue machine (errResult errInvalidString)
finishValue machine (errResult "invalid string")
APutBytes bs ->
case decodeBytes bs "PutBytes" of
Right b ->
pure (AsyncAction (BS.putStr b >> pure Leaf) machine)
Left _ ->
finishValue machine (errResult errInvalidString)
finishValue machine (errResult "invalid bytes")
AGetLine ->
pure (AsyncAction (ofString <$> getLine) machine)
@@ -386,7 +455,7 @@ stepMachine machine =
case mDeny of
Just denied -> finishValue machine denied
Nothing -> pure (AsyncAction (tryReadFile p) machine)
Left _ -> finishValue machine (errResult errInvalidString)
Left _ -> finishValue machine (errResult "invalid string")
AWriteFile path contents ->
case decodeString path "WriteFile" of
@@ -397,8 +466,8 @@ stepMachine machine =
case mDeny of
Just denied -> finishValue machine denied
Nothing -> pure (AsyncAction (tryWriteFile p c) machine)
Left _ -> finishValue machine (errResult errInvalidString)
Left _ -> finishValue machine (errResult errInvalidString)
Left _ -> finishValue machine (errResult "invalid string")
Left _ -> finishValue machine (errResult "invalid string")
AWriteBytes path contents ->
case decodeString path "WriteBytes" of
@@ -409,8 +478,8 @@ stepMachine machine =
case mDeny of
Just denied -> finishValue machine denied
Nothing -> pure (AsyncAction (tryWriteFileBytes p c) machine)
Left _ -> finishValue machine (errResult errInvalidString)
Left _ -> finishValue machine (errResult errInvalidString)
Left _ -> finishValue machine (errResult "invalid bytes")
Left _ -> finishValue machine (errResult "invalid string")
AAsk ->
finishValue machine (rtEnv (machineRuntime machine))
@@ -453,6 +522,207 @@ stepMachine machine =
_ ->
finishValue machine invalidSleepResult
ASocket -> do
result <- try (NS.socket NS.AF_INET NS.Stream NS.defaultProtocol) :: IO (Either SomeException NS.Socket)
case result of
Left e ->
finishValue machine (errResult ("io error: " ++ show e))
Right sock -> do
NS.setSocketOption sock NS.ReuseAddr 1
sid <- atomically $ do
SocketRegistry m next <- readTVar sockVar
let sid = SockId next
writeTVar sockVar (SocketRegistry (Map.insert sid sock m) (next + 1))
return sid
finishValue machine (okResult (sockHandle sid))
ACloseSocket sockTree ->
case decodeSockHandle sockTree of
Left _ -> finishValue machine invalidSocketHandleResult
Right sid -> do
mSock <- atomically $ do
SocketRegistry m next <- readTVar sockVar
case Map.lookup sid m of
Nothing -> return Nothing
Just sock -> do
writeTVar sockVar (SocketRegistry (Map.delete sid m) next)
return (Just sock)
case mSock of
Nothing -> finishValue machine invalidSocketHandleResult
Just sock -> do
NS.close sock
finishValue machine (okResult Leaf)
ABindSocket sockTree addrTree portTree ->
case decodeSockHandle sockTree of
Left _ -> finishValue machine invalidSocketHandleResult
Right sid ->
case decodeString addrTree "BindSocket" of
Left _ -> finishValue machine (errResult "invalid address")
Right addrStr ->
case toNumber portTree of
Left _ -> finishValue machine (errResult "invalid port")
Right port -> do
mSock <- atomically $ do
SocketRegistry m _ <- readTVar sockVar
return (Map.lookup sid m)
case mSock of
Nothing -> finishValue machine invalidSocketHandleResult
Just sock -> do
result <- try (do
addrInfo <- NS.getAddrInfo (Just $ NS.defaultHints { NS.addrSocketType = NS.Stream })
(Just addrStr)
(Just (show port))
let serverAddr = head addrInfo
NS.bind sock (NS.addrAddress serverAddr)
) :: IO (Either SomeException ())
case result of
Left e ->
finishValue machine (errResult ("io error: " ++ show e))
Right () ->
finishValue machine (okResult Leaf)
AListen sockTree backlogTree ->
case decodeSockHandle sockTree of
Left _ -> finishValue machine invalidSocketHandleResult
Right sid ->
case toNumber backlogTree of
Left _ -> finishValue machine (errResult "invalid backlog")
Right backlog -> do
mSock <- atomically $ do
SocketRegistry m _ <- readTVar sockVar
return (Map.lookup sid m)
case mSock of
Nothing -> finishValue machine invalidSocketHandleResult
Just sock -> do
result <- try (NS.listen sock (fromIntegral backlog)) :: IO (Either SomeException ())
case result of
Left e ->
finishValue machine (errResult ("io error: " ++ show e))
Right () ->
finishValue machine (okResult Leaf)
AAccept listenTree ->
case decodeSockHandle listenTree of
Left _ -> finishValue machine invalidSocketHandleResult
Right listenSid ->
pure (AsyncAction (do
mListenSock <- atomically $ do
SocketRegistry m _ <- readTVar sockVar
return (Map.lookup listenSid m)
case mListenSock of
Nothing -> return (errResult "invalid socket handle")
Just listenSock -> do
result <- try (NS.accept listenSock) :: IO (Either SomeException (NS.Socket, NS.SockAddr))
case result of
Left e ->
return (errResult ("io error: " ++ show e))
Right (clientSock, addr) -> do
clientSid <- atomically $ do
SocketRegistry m next <- readTVar sockVar
let sid = SockId next
writeTVar sockVar (SocketRegistry (Map.insert sid clientSock m) (next + 1))
return sid
let addrStr = case addr of
NS.SockAddrInet p h ->
let (a,b,c,d) = NS.hostAddressToTuple h
in show a ++ "." ++ show b ++ "." ++ show c ++ "." ++ show d ++ ":" ++ show p
_ -> show addr
return (okResult (Fork (sockHandle clientSid) (ofString addrStr)))
) machine)
AConnect sockTree addrTree portTree ->
case decodeSockHandle sockTree of
Left _ -> finishValue machine invalidSocketHandleResult
Right sid ->
case decodeString addrTree "Connect" of
Left _ -> finishValue machine (errResult "invalid address")
Right addrStr ->
case toNumber portTree of
Left _ -> finishValue machine (errResult "invalid port")
Right port -> do
mSock <- atomically $ do
SocketRegistry m _ <- readTVar sockVar
return (Map.lookup sid m)
case mSock of
Nothing -> finishValue machine invalidSocketHandleResult
Just sock ->
pure (AsyncAction (do
result <- try (do
addrInfo <- NS.getAddrInfo (Just $ NS.defaultHints { NS.addrSocketType = NS.Stream })
(Just addrStr)
(Just (show port))
let serverAddr = head addrInfo
NS.connect sock (NS.addrAddress serverAddr)
) :: IO (Either SomeException ())
case result of
Left e ->
return (errResult ("io error: " ++ show e))
Right () ->
return (okResult Leaf)
) machine)
ARecv sockTree maxBytesTree ->
case decodeSockHandle sockTree of
Left _ -> finishValue machine invalidSocketHandleResult
Right sid ->
case toNumber maxBytesTree of
Left _ -> finishValue machine (errResult "invalid maxBytes")
Right maxBytes -> do
mSock <- atomically $ do
SocketRegistry m _ <- readTVar sockVar
return (Map.lookup sid m)
case mSock of
Nothing -> finishValue machine invalidSocketHandleResult
Just sock ->
pure (AsyncAction (do
result <- try (NSB.recv sock (fromIntegral maxBytes)) :: IO (Either SomeException BS.ByteString)
case result of
Left e ->
return (errResult ("io error: " ++ show e))
Right bs ->
if BS.null bs
then return (errResult "connection closed")
else return (okResult (ofBytes bs))
) machine)
AGetSocketName sockTree ->
case decodeSockHandle sockTree of
Left _ -> finishValue machine invalidSocketHandleResult
Right sid -> do
mSock <- atomically $ do
SocketRegistry m _ <- readTVar sockVar
return (Map.lookup sid m)
case mSock of
Nothing -> finishValue machine invalidSocketHandleResult
Just sock -> do
mPort <- getSocketPort sock
case mPort of
Just port -> finishValue machine (okResult (ofNumber port))
Nothing -> finishValue machine (errResult "io error: could not get socket name")
ASend sockTree bytesTree ->
case decodeSockHandle sockTree of
Left _ -> finishValue machine invalidSocketHandleResult
Right sid ->
case decodeBytes bytesTree "Send" of
Left _ -> finishValue machine (errResult "invalid bytes")
Right bs -> do
mSock <- atomically $ do
SocketRegistry m _ <- readTVar sockVar
return (Map.lookup sid m)
case mSock of
Nothing -> finishValue machine invalidSocketHandleResult
Just sock ->
pure (AsyncAction (do
result <- try (NSB.send sock bs) :: IO (Either SomeException Int)
case result of
Left e ->
return (errResult ("io error: " ++ show e))
Right sent ->
return (okResult (ofNumber (fromIntegral sent)))
) machine)
-- Permission and IO helpers
checkReadPerm p =
if allowReadAll (rtPerms (machineRuntime machine))
@@ -480,7 +750,7 @@ stepMachine machine =
then return Nothing
else return $ Just policyErrResult
policyErrResult = errResult errPolicyDeny
policyErrResult = errResult "permission denied"
canonicalizeSafe :: FilePath -> IO (Either String FilePath)
canonicalizeSafe p = do
@@ -534,19 +804,19 @@ stepMachine machine =
result <- try (BS.readFile path) :: IO (Either IOException BS.ByteString)
case result of
Right content -> return $ okResult (ofBytes content)
Left e -> return $ errResult (ioErrorCode e)
Left e -> return $ errResult (ioErrorString e)
tryWriteFile path contents = do
result <- try (IO.writeFile path contents) :: IO (Either IOException ())
case result of
Right () -> return $ okResult Leaf
Left e -> return $ errResult (ioErrorCode e)
Left e -> return $ errResult (ioErrorString e)
tryWriteFileBytes path contents = do
result <- try (BS.writeFile path contents) :: IO (Either IOException ())
case result of
Right () -> return $ okResult Leaf
Left e -> return $ errResult (ioErrorCode e)
Left e -> return $ errResult (ioErrorString e)
decodeString t ctx =
case toString t of
@@ -577,6 +847,8 @@ data Scheduler = Scheduler
, schedulerSleepQueue :: Map UTCTime (Set TaskId)
, schedulerAsyncCompleted :: TVar (Map TaskId T)
, schedulerCompleted :: Map TaskId (T, T)
, schedulerSockets :: TVar SocketRegistry
, schedulerNextSockId :: Integer
}
instance Show Scheduler where
@@ -587,10 +859,12 @@ instance Show Scheduler where
++ ", schedulerSleepQueue = " ++ show (schedulerSleepQueue s)
++ ", schedulerAsyncCompleted = <tvar>"
++ ", schedulerCompleted = " ++ show (schedulerCompleted s)
++ ", schedulerSockets = <tvar>"
++ ", schedulerNextSockId = " ++ show (schedulerNextSockId s)
++ " }"
initialScheduler :: TVar (Map TaskId T) -> Machine -> Scheduler
initialScheduler asyncVar mainMachine =
initialScheduler :: TVar (Map TaskId T) -> TVar SocketRegistry -> Machine -> Scheduler
initialScheduler asyncVar sockVar mainMachine =
Scheduler
{ schedulerNextTaskId = 1
, schedulerRunnable = Seq.singleton (TaskId 0)
@@ -599,6 +873,8 @@ initialScheduler asyncVar mainMachine =
, schedulerSleepQueue = Map.empty
, schedulerAsyncCompleted = asyncVar
, schedulerCompleted = Map.empty
, schedulerSockets = sockVar
, schedulerNextSockId = 0
}
runtimeOfStatus :: TaskStatus -> Maybe Runtime
@@ -732,7 +1008,7 @@ handleStep currentId (AwaitRequested targetId machine) scheduler
Just (BlockedOn nextId _) ->
if wouldCycle targetId currentId (schedulerTasks scheduler)
then resumeCurrentWith currentId (errResult errCyclicAwait) machine scheduler
then resumeCurrentWith currentId (errResult "cyclic await") machine scheduler
else block
Just _ -> block
@@ -759,8 +1035,11 @@ handleStep taskId (SleepRequested ms machine) scheduler = do
handleStep taskId (AsyncAction ioAction machine) scheduler = do
_ <- forkIO $ do
result <- ioAction
atomically $ modifyTVar' (schedulerAsyncCompleted scheduler) (Map.insert taskId result)
result <- (Right <$> ioAction) `catch` \(e :: SomeException) -> pure (Left (show e))
atomically $ modifyTVar' (schedulerAsyncCompleted scheduler) (Map.insert taskId $
case result of
Right val -> val
Left msg -> errResult msg)
pure scheduler
{ schedulerTasks = Map.insert taskId (AsyncWaiting machine) (schedulerTasks scheduler)
}
@@ -818,7 +1097,7 @@ schedulerStep scheduler = do
taskId :< restQueue ->
case Map.lookup taskId (schedulerTasks scheduler1) of
Just (Runnable machine) -> do
step <- stepMachine machine
step <- stepMachine (schedulerSockets scheduler1) machine
handleStep taskId step scheduler1 { schedulerRunnable = restQueue }
_ ->
@@ -843,6 +1122,7 @@ runIOWith perms env initialState action =
Left err -> pure (Left err)
Right (_, action') -> do
asyncVar <- newTVarIO Map.empty
sockVar <- newTVarIO (SocketRegistry Map.empty 0)
let initialMachine = Machine
{ machineRuntime = Runtime
{ rtPerms = perms
@@ -852,7 +1132,7 @@ runIOWith perms env initialState action =
, machineCurrent = action'
, machineFrames = []
}
Right <$> runScheduler (initialScheduler asyncVar initialMachine)
Right <$> runScheduler (initialScheduler asyncVar sockVar initialMachine)
runIOWithEnv :: IOPermissions -> T -> T -> IO (Either String T)
runIOWithEnv perms env action = do

View File

@@ -10,7 +10,8 @@ import Wire
import ContentStore
import IODriver (IOPermissions(..), checkIOSentinel, runIO, runIOWithEnv, runIOWith, unsafePerms, defaultPerms)
import Control.Exception (evaluate, try, SomeException)
import Control.Exception (bracket, evaluate, try, SomeException)
import qualified Network.Socket as NS
import Control.Monad (forM_)
import Control.Monad.IO.Class (liftIO)
import System.IO.Temp (withSystemTempDirectory)
@@ -1553,15 +1554,15 @@ ioDriverTests = testGroup "IO driver tests"
-- Malformed action tests
, testCase "unknown IO action tag returns err result" $ do
final <- runIOSource "main = io (pair 99 t)"
final @?= ioErrResult 40
final @?= ioErrResult "invalid action"
, testCase "malformed Bind returns err result" $ do
final <- runIOSource "main = io (pair 1 t)"
final @?= ioErrResult 40
final @?= ioErrResult "invalid action"
, testCase "malformed ReadFile payload returns err result" $ do
final <- runIOSource "main = io (readFile (t t))"
final @?= ioErrResult 41
final @?= ioErrResult "invalid string"
-- Permission tests
, testCase "allowed read path succeeds" $
@@ -1586,7 +1587,7 @@ ioDriverTests = testGroup "IO driver tests"
unlines
[ "main = io (readFile \"" ++ deniedPath ++ "\")"
]
result @?= ioErrResult 20
result @?= ioErrResult "permission denied"
, testCase "writeFile denied path returns err result" $
withSystemTempDirectory "tricu-io-write-denied" $ \dir -> do
@@ -1597,7 +1598,7 @@ ioDriverTests = testGroup "IO driver tests"
unlines
[ "main = io (writeFile \"" ++ deniedPath ++ "\" \"x\")"
]
result @?= ioErrResult 20
result @?= ioErrResult "permission denied"
, testCase "path prefix does not allow prefix bypass" $
withSystemTempDirectory "tricu-io-prefix" $ \dir -> do
@@ -1611,7 +1612,7 @@ ioDriverTests = testGroup "IO driver tests"
unlines
[ "main = io (readFile \"" ++ bypassPath ++ "\")"
]
result @?= ioErrResult 20
result @?= ioErrResult "permission denied"
-- Pure test
, testCase "pure performs no effects" $ do
@@ -1820,14 +1821,14 @@ ioDriverTests = testGroup "IO driver tests"
unlines
[ "main = io (await (pair \"task\" 0))"
]
final @?= ioErrResult 61
final @?= ioErrResult "self await"
, testCase "await invalid handle returns async error" $ do
(final, _) <- runIOSourceWith unsafePerms Leaf Leaf $
unlines
[ "main = io (await 123)"
]
final @?= ioErrResult 60
final @?= ioErrResult "invalid task handle"
, testCase "yield returns unit and resumes continuation" $ do
(final, _) <- runIOSourceWith unsafePerms Leaf Leaf $
@@ -1890,7 +1891,7 @@ ioDriverTests = testGroup "IO driver tests"
[ "main = io (bind (fork (await (pair \"task\" 0))) (h :"
, " await h))"
]
final @?= ioErrResult 63
final @?= ioErrResult "cyclic await"
, testCase "writeBytes and readFile roundtrip binary data" $
withSystemTempDirectory "tricu-io-bytes" $ \dir -> do
@@ -1918,12 +1919,196 @@ ioDriverTests = testGroup "IO driver tests"
build k = "bind (fork (pure \"x\")) (h : bind (await h) (_ : " ++ build (k - 1) ++ "))"
(final, _) <- runIOSourceWith unsafePerms Leaf Leaf ("main = io (" ++ build n ++ ")")
final @?= ofString "done"
, testGroup "Socket primitives"
[ testCase "socket returns ok result with valid handle" $ do
final <- runIOSource "main = io socket"
final @?= ioOkResult (Fork (ofString "sock") (ofNumber 0))
, testCase "closeSocket on invalid handle returns error" $ do
final <- runIOSource "main = io (closeSocket (pair \"sock\" 99999))"
final @?= ioErrResult "invalid socket handle"
, testCase "bindSocket and listen succeed on loopback port 0" $ do
final <- runIOSource "main = io (bind socket (result : matchResult (err rest : pure result) (sock rest : bind (bindSocket sock \"127.0.0.1\" 0) (bindResult : matchResult (err rest : pure bindResult) (_ rest : bind (listen sock 1) (listenResult : pure listenResult)) bindResult)) result))"
final @?= ioOkResult Leaf
, testCase "connect to non-listening port returns error" $ do
final <- runIOSource $
unlines
[ "main = io (bind socket (result :"
, " matchResult"
, " (err rest : pure \"socket-err\")"
, " (sock rest : connect sock \"127.0.0.1\" 1)"
, " result))"
]
case final of
Fork Leaf (Fork _ Leaf) -> return ()
other -> assertFailure $ "Expected error result, got: " ++ show other
, testCase "accept and recv receive bytes from forked client" $
withFreePort $ \port -> do
final <- runIOSource $
unlines
[ "preserveResult = (result okCase :"
, " matchResult"
, " (err rest : pure result)"
, " okCase"
, " result)"
, ""
, "client = port :"
, " bind socket (result :"
, " preserveResult result (sock rest :"
, " bind (connect sock \"127.0.0.1\" port) (connectResult :"
, " preserveResult connectResult (_ rest :"
, " send sock [104 105]))))"
, ""
, "main = io ("
, " bind socket (result :"
, " preserveResult result (server rest :"
, " bind (bindSocket server \"127.0.0.1\" " ++ show port ++ ") (bindResult :"
, " preserveResult bindResult (_ rest :"
, " bind (listen server 1) (listenResult :"
, " preserveResult listenResult (_ rest :"
, " bind (fork (client " ++ show port ++ ")) (_ :"
, " bind (accept server) (acceptResult :"
, " preserveResult acceptResult (accepted rest :"
, " matchPair"
, " (clientSock addr : recv clientSock 2)"
, " accepted))))))))))"
]
final @?= ioOkResult (ofBytes (BS.pack [104, 105]))
, testCase "client recv receives server response via accepted socket" $
withFreePort $ \port -> do
final <- runIOSource $
unlines
[ "preserveResult = (result okCase :"
, " matchResult"
, " (err rest : pure result)"
, " okCase"
, " result)"
, ""
, "serverTask = server :"
, " bind (accept server) (acceptResult :"
, " preserveResult acceptResult (accepted rest :"
, " matchPair"
, " (clientSock addr :"
, " bind (recv clientSock 4) (msgResult :"
, " preserveResult msgResult (_ rest :"
, " send clientSock [112 111 110 103])))"
, " accepted))"
, ""
, "clientTask = port :"
, " bind socket (result :"
, " preserveResult result (sock rest :"
, " bind (connect sock \"127.0.0.1\" port) (connectResult :"
, " preserveResult connectResult (_ rest :"
, " bind (send sock [112 105 110 103]) (_ :"
, " recv sock 4)))))"
, ""
, "main = io ("
, " bind socket (result :"
, " preserveResult result (server rest :"
, " bind (bindSocket server \"127.0.0.1\" " ++ show port ++ ") (bindResult :"
, " preserveResult bindResult (_ rest :"
, " bind (listen server 1) (listenResult :"
, " preserveResult listenResult (_ rest :"
, " bind (fork (serverTask server)) (_ :"
, " clientTask " ++ show port ++ "))))))))"
]
final @?= ioOkResult (ofBytes (BS.pack [112, 111, 110, 103]))
, testCase "recv on closed peer returns connection closed" $
withFreePort $ \port -> do
final <- runIOSource $
unlines
[ "preserveResult = (result okCase :"
, " matchResult"
, " (err rest : pure result)"
, " okCase"
, " result)"
, ""
, "clientTask = port :"
, " bind socket (result :"
, " preserveResult result (sock rest :"
, " bind (connect sock \"127.0.0.1\" port) (connectResult :"
, " preserveResult connectResult (_ rest :"
, " closeSocket sock))))"
, ""
, "main = io ("
, " bind socket (result :"
, " preserveResult result (server rest :"
, " bind (bindSocket server \"127.0.0.1\" " ++ show port ++ ") (bindResult :"
, " preserveResult bindResult (_ rest :"
, " bind (listen server 1) (listenResult :"
, " preserveResult listenResult (_ rest :"
, " bind (fork (clientTask " ++ show port ++ ")) (_ :"
, " bind (accept server) (acceptResult :"
, " preserveResult acceptResult (accepted rest :"
, " matchPair"
, " (clientSock addr :"
, " bind (yield) (_ :"
, " recv clientSock 1))"
, " accepted))))))))))"
]
final @?= ioErrResult "connection closed"
, testCase "accept invalid socket handle returns error" $ do
final <- runIOSource "main = io (accept (pair \"sock\" 99999))"
final @?= ioErrResult "invalid socket handle"
, testCase "recv invalid socket handle returns error" $ do
final <- runIOSource "main = io (recv (pair \"sock\" 99999) 1)"
final @?= ioErrResult "invalid socket handle"
, testCase "send invalid socket handle returns error" $ do
final <- runIOSource "main = io (send (pair \"sock\" 99999) [(1)])"
final @?= ioErrResult "invalid socket handle"
, testCase "getSocketName returns positive port after bind 0" $ do
final <- runIOSource $
unlines
[ "preserveResult = (result okCase :"
, " matchResult"
, " (err rest : pure result)"
, " okCase"
, " result)"
, ""
, "main = io ("
, " bind socket (result :"
, " preserveResult result (server rest :"
, " bind (bindSocket server \"127.0.0.1\" 0) (bindResult :"
, " preserveResult bindResult (_ rest :"
, " getSocketName server)))))"
]
case final of
Fork (Stem Leaf) (Fork val Leaf) ->
case toNumber val of
Right port | port > 0 -> return ()
Right 0 -> assertFailure "Expected positive port, got 0"
Left _ -> assertFailure $ "Expected numeric port, got: " ++ show val
other -> assertFailure $ "Expected ok result, got: " ++ show other
]
]
withFreePort :: (Int -> IO a) -> IO a
withFreePort action =
bracket
(NS.socket NS.AF_INET NS.Stream NS.defaultProtocol)
NS.close
(\s -> do
NS.setSocketOption s NS.ReuseAddr 1
NS.bind s (NS.SockAddrInet 0 (NS.tupleToHostAddress (127, 0, 0, 1)))
port <- NS.socketPort s
action (fromIntegral port))
runIOSourceWith :: IOPermissions -> T -> T -> String -> IO (T, T)
runIOSourceWith perms readerEnv initialState source = do
ioEnv <- evaluateFile "./lib/io.tri"
evalEnv <- evalTricuWithStore Nothing ioEnv (parseTricu source)
sockEnv <- evaluateFile "./lib/socket.tri"
let combinedEnv = Map.union sockEnv ioEnv
evalEnv <- evalTricuWithStore Nothing combinedEnv (parseTricu source)
let fullTree = mainResult evalEnv
result <- runIOWith perms readerEnv initialState fullTree
case result of
@@ -1942,5 +2127,5 @@ runIOSourceWithEnv perms readerEnv source = fmap fst $ runIOSourceWith perms rea
ioOkResult :: T -> T
ioOkResult val = Fork (Stem Leaf) (Fork val Leaf)
ioErrResult :: Integer -> T
ioErrResult code = Fork Leaf (Fork (ofNumber code) Leaf)
ioErrResult :: String -> T
ioErrResult msg = Fork Leaf (Fork (ofString msg) Leaf)

View File

@@ -52,6 +52,7 @@ executable tricu
, megaparsec
, memory
, mtl
, network
, servant
, sqlite-simple
, stm
@@ -102,6 +103,7 @@ benchmark tricu-bench
, megaparsec
, memory
, mtl
, network
, sqlite-simple
, text
, time
@@ -148,6 +150,7 @@ test-suite tricu-tests
, megaparsec
, memory
, mtl
, network
, servant
, sqlite-simple
, stm