Perhaps the first webserver in Tree Calculus? Sure, it's married to a Haskell IO runtime... but we're managing all of the actual webserver semantics in tricu! This includes a demo Arboricx application server that is capable of storing and serving bundles.
1353 lines
49 KiB
Haskell
1353 lines
49 KiB
Haskell
module IODriver
|
|
( IOPermissions(..)
|
|
, defaultPerms
|
|
, unsafePerms
|
|
, checkIOSentinel
|
|
, runIO
|
|
, runIOWithEnv
|
|
, runIOWith
|
|
) where
|
|
|
|
import Research (T(..), apply, toString, toNumber, ofString, ofNumber, ofBytes, toBytes, ofList)
|
|
import qualified Data.ByteString as BS
|
|
import System.IO (putStr, getLine)
|
|
import qualified System.IO as IO
|
|
import Control.Exception (try, catch, IOException, SomeException)
|
|
import System.IO.Error (isDoesNotExistError, isPermissionError, isAlreadyExistsError)
|
|
import Data.List (isPrefixOf, isInfixOf)
|
|
import System.FilePath (normalise, isRelative, (</>), addTrailingPathSeparator, splitDirectories, takeDirectory)
|
|
import System.Directory (canonicalizePath, doesPathExist, getCurrentDirectory, listDirectory, createDirectory, renameFile, removeFile, doesDirectoryExist)
|
|
import Data.Time.Clock.POSIX (getPOSIXTime)
|
|
import Crypto.Hash (hash, SHA256, Digest)
|
|
import Data.ByteArray (convert)
|
|
import Data.ByteString.Base16 (encode)
|
|
import Data.Text.Encoding (decodeUtf8)
|
|
import qualified Data.Text as T
|
|
import Data.Char (toLower)
|
|
import qualified Data.Map.Strict as Map
|
|
import Data.Map.Strict (Map)
|
|
import qualified Data.Sequence as Seq
|
|
import Data.Sequence (Seq, (|>), ViewL(..))
|
|
import Data.Time.Clock (UTCTime, getCurrentTime, addUTCTime, diffUTCTime)
|
|
import Control.Concurrent (threadDelay, forkIO)
|
|
import Control.Concurrent.STM (TVar, newTVarIO, atomically, readTVar, writeTVar, modifyTVar', retry)
|
|
import qualified Data.Set as Set
|
|
import Data.Set (Set)
|
|
import qualified Data.Foldable as Fold
|
|
import qualified Network.Socket as NS
|
|
import qualified Network.Socket.ByteString as NSB
|
|
|
|
-- ---------------------------------------------------------------------------
|
|
-- Permissions
|
|
-- ---------------------------------------------------------------------------
|
|
|
|
-- | Filesystem access policy consulted before path-based IO actions run.
data IOPermissions = IOPermissions
  { allowRead     :: [FilePath]
    -- ^ Path prefixes under which read-style actions are permitted.
  , allowWrite    :: [FilePath]
    -- ^ Path prefixes under which write-style actions are permitted.
  , allowReadAll  :: Bool
    -- ^ When 'True', the read prefix list is not consulted at all.
  , allowWriteAll :: Bool
    -- ^ When 'True', the write prefix list is not consulted at all.
  }
  deriving (Show)
|
|
|
|
-- | Most restrictive policy: empty prefix lists and both blanket flags off,
-- so no path is allowed for reading or writing.
defaultPerms :: IOPermissions
defaultPerms =
  IOPermissions
    { allowRead     = []
    , allowWrite    = []
    , allowReadAll  = False
    , allowWriteAll = False
    }
|
|
|
|
-- | Fully permissive policy: both blanket flags are on, so the (empty)
-- prefix lists are never consulted.
unsafePerms :: IOPermissions
unsafePerms =
  IOPermissions
    { allowRead     = []
    , allowWrite    = []
    , allowReadAll  = True
    , allowWriteAll = True
    }
|
|
|
|
-- | Recognise an IO program root of the shape
-- @Fork sentinel (Fork version action)@.
--
-- The sentinel must decode (via 'toString') to the literal @"tricuIO"@ and
-- the version must decode via 'toNumber'.  On success the decoded version and
-- the untouched action tree are returned; otherwise a 'Left' message explains
-- which part of the shape failed.
checkIOSentinel :: T -> Either String (Integer, T)
checkIOSentinel (Fork sentinel (Fork version action)) = do
  name <- toString sentinel
  if name == "tricuIO"
    then do
      v <- toNumber version
      pure (v, action)
    else Left "sentinel mismatch (expected \"tricuIO\")"
checkIOSentinel _ =
  Left "root is not an IO sentinel pair"
|
|
|
|
-- ---------------------------------------------------------------------------
|
|
-- Runtime, Frames, and Machine
|
|
-- ---------------------------------------------------------------------------
|
|
|
|
-- | Per-machine runtime context threaded through every step.
data Runtime = Runtime
  { rtPerms :: IOPermissions
    -- ^ Filesystem policy checked before path-based actions.
  , rtEnv   :: T
    -- ^ Environment returned by the Ask action and rebound by Local.
  , rtState :: T
    -- ^ State returned by the Get action and replaced by Put.
  }
  deriving (Show)
|
|
|
|
-- | One entry of a machine's continuation stack.
data Frame
  = BindFrame T
    -- ^ Continuation to apply to the value produced by the current action.
  | LocalFrame T
    -- ^ Environment to restore once the action run under Local finishes.
  deriving (Show)
|
|
|
|
-- | Complete state of one small-step evaluation.
data Machine = Machine
  { machineRuntime :: Runtime
    -- ^ Permissions, environment and state in effect.
  , machineCurrent :: T
    -- ^ Action tree that the next step will decode.
  , machineFrames  :: [Frame]
    -- ^ Pending frames; the head is the most recently pushed one.
  }
  deriving (Show)
|
|
|
|
-- ---------------------------------------------------------------------------
|
|
-- Result convention
|
|
-- ---------------------------------------------------------------------------
|
|
-- Direct-return actions pass the raw value to the continuation:
--   pure, bind, putStr, putBytes, getLine, ask, local, get, put,
--   fork, await, yield, sleep
--
-- Result-return actions wrap the outcome as an ok/err pair:
--   ok val   = Fork (Stem Leaf) (Fork val Leaf)   -- (t t) val t
--   err code = Fork Leaf (Fork code Leaf)         -- t code t
--   readFile, writeFile, writeBytes, listDirectory, renameFile,
--   createDirectory, deleteFile, fileExists, and all socket actions
--
-- Runtime protocol errors (undecodable actions, malformed arguments, bad
-- handles) are returned as direct values built with errResult, i.e. they
-- have the err shape above even for direct-return actions.
|
|
|
|
-- | Wrap a successful payload in the ok shape
-- @Fork (Stem Leaf) (Fork val Leaf)@.
okResult :: T -> T
okResult val = Fork okMarker body
  where
    okMarker = Stem Leaf
    body     = Fork val Leaf
|
|
|
|
-- | Build the err shape @Fork Leaf (Fork code Leaf)@, where @code@ is the
-- message encoded with 'ofString'.
errResult :: String -> T
errResult = Fork Leaf . (`Fork` Leaf) . ofString
|
|
|
|
-- | Build a Pure action tree around a value.  The literal @0@ is the numeric
-- tag that the action decoder maps to a Pure action.
pureAction :: T -> T
pureAction = Fork (ofNumber 0)
|
|
|
|
-- Canned error values, all in the err shape produced by errResult.
invalidAsyncHandleResult, invalidSocketHandleResult, selfAwaitResult,
  deadlockResult, invalidSleepResult :: T

-- Returned when an await target does not decode as a task handle.
invalidAsyncHandleResult = errResult "invalid task handle"

-- Returned when a socket handle fails to decode or is not in the registry.
invalidSocketHandleResult = errResult "invalid socket handle"

-- NOTE(review): presumably used when a task awaits itself; the use site is
-- not in this part of the file.
selfAwaitResult = errResult "self await"

-- NOTE(review): presumably used when no task can make progress; the use
-- site is not in this part of the file.
deadlockResult = errResult "deadlock"

-- Returned when a sleep duration is not a number or is negative.
invalidSleepResult = errResult "invalid sleep"
|
|
|
|
-- | Map an 'IOException' to one of the short error codes handed back to
-- programs.  Classifiers are tried in order; anything unrecognised becomes
-- @"io error"@.
ioErrorString :: IOException -> String
ioErrorString e =
  case [label | (matches, label) <- classifiers, matches e] of
    label : _ -> label
    []        -> "io error"
  where
    classifiers =
      [ (isDoesNotExistError,  "does not exist")
      , (isPermissionError,    "permission denied")
      , (isAlreadyExistsError, "already exists")
      ]
|
|
|
|
-- ---------------------------------------------------------------------------
|
|
-- Task identity and handles
|
|
-- ---------------------------------------------------------------------------
|
|
|
|
-- | Numeric identity of a task; encoded into trees by 'taskHandle'.
newtype TaskId = TaskId Integer
  deriving (Eq, Ord, Show)
|
|
|
|
-- | Encode a task id as the tree @Fork "task" n@ (tag via 'ofString',
-- number via 'ofNumber').
taskHandle :: TaskId -> T
taskHandle tid = Fork tag ident
  where
    TaskId n = tid
    tag      = ofString "task"
    ident    = ofNumber n
|
|
|
|
-- | Inverse of 'taskHandle': accept @Fork tag n@ where the tag decodes to
-- @"task"@ and @n@ decodes as a number.  Any other shape, a tag that is not
-- a string, or a non-numeric id yields 'Left'.
decodeTaskHandle :: T -> Either String TaskId
decodeTaskHandle (Fork tag nTree) =
  toString tag >>= \label ->
    case label of
      "task" -> TaskId <$> toNumber nTree
      _      -> Left "invalid task handle tag"
decodeTaskHandle _ =
  Left "invalid task handle"
|
|
|
|
-- ---------------------------------------------------------------------------
|
|
-- Socket identity and handles
|
|
-- ---------------------------------------------------------------------------
|
|
|
|
-- | Numeric identity of a registered socket; encoded into trees by
-- 'sockHandle' and used as the key of the socket registry.
newtype SockId = SockId Integer
  deriving (Eq, Ord, Show)
|
|
|
|
-- | Encode a socket id as the tree @Fork "sock" n@ (tag via 'ofString',
-- number via 'ofNumber').
sockHandle :: SockId -> T
sockHandle sid =
  let SockId n = sid
      tag      = ofString "sock"
   in Fork tag (ofNumber n)
|
|
|
|
-- | Inverse of 'sockHandle': accept @Fork tag n@ where the tag decodes to
-- @"sock"@ and @n@ decodes as a number.  Any other shape, a tag that is not
-- a string, or a non-numeric id yields 'Left'.
decodeSockHandle :: T -> Either String SockId
decodeSockHandle (Fork tag nTree) = do
  label <- toString tag
  case label of
    "sock" -> SockId <$> toNumber nTree
    _      -> Left "invalid socket handle tag"
decodeSockHandle _ =
  Left "invalid socket handle"
|
|
|
|
-- | Look up the local port a socket is bound to.
--
-- Returns 'Nothing' when the socket's name is neither an IPv4 nor an IPv6
-- address.  Exceptions raised by 'NS.getSocketName' are not caught here.
getSocketPort :: NS.Socket -> IO (Maybe Integer)
getSocketPort sock = portOf <$> NS.getSocketName sock
  where
    portOf (NS.SockAddrInet port _)      = Just (fromIntegral port)
    portOf (NS.SockAddrInet6 port _ _ _) = Just (fromIntegral port)
    portOf _                             = Nothing
|
|
|
|
-- ---------------------------------------------------------------------------
|
|
-- Socket registry
|
|
-- ---------------------------------------------------------------------------
|
|
|
|
-- | Table of live sockets addressed by the handles given to programs.
data SocketRegistry = SocketRegistry
  { sockMap    :: Map SockId NS.Socket
    -- ^ Sockets currently registered, keyed by their handle id.
  , sockNextId :: Integer
    -- ^ Id handed to the next registered socket; incremented on each insert.
  }
|
|
|
|
-- ---------------------------------------------------------------------------
|
|
-- Free-monad action AST
|
|
-- ---------------------------------------------------------------------------
|
|
|
|
-- | Decoded form of an action tree.  Every 'T' field is still an undecoded
-- tree; the dispatcher decodes strings, bytes, numbers and handles on demand.
data Action
  = APure T              -- ^ Value to hand to the continuation.
  | ABind T T            -- ^ Action to run first, then the continuation for its result.
  | APutStr T            -- ^ String to write to standard output.
  | APutBytes T          -- ^ Bytes to write to standard output.
  | AGetLine             -- ^ Read one line from standard input.
  | AReadFile T          -- ^ Path to read.
  | AWriteFile T T       -- ^ Path, then string contents.
  | AWriteBytes T T      -- ^ Path, then byte contents.
  | AListDirectory T     -- ^ Directory path to list.
  | ARenameFile T T      -- ^ Old path, then new path.
  | ACreateDirectory T   -- ^ Directory path to create.
  | ADeleteFile T        -- ^ File path to delete.
  | AFileExists T        -- ^ Path to test for existence.
  | ASha256Hex T         -- ^ Bytes to hash.
  | ACurrentTime         -- ^ Query the current time.
  | AAsk                 -- ^ Read the runtime environment.
  | ALocal T T           -- ^ Function applied to the environment, then the action run under it.
  | AGet                 -- ^ Read the runtime state.
  | APut T               -- ^ New runtime state.
  | AFork T              -- ^ Child action to run as a new task.
  | AAwait T             -- ^ Task handle to wait on.
  | AYield               -- ^ Give up the current turn.
  | ASleep T             -- ^ Non-negative duration (the dispatcher names it @ms@).
  | ASocket              -- ^ Create and register a new socket.
  | ACloseSocket T       -- ^ Socket handle to close.
  | ABindSocket T T T    -- ^ Socket handle, address string, port number.
  | AListen T T          -- ^ Socket handle, backlog.
  | AAccept T            -- ^ Listening socket handle.
  | AConnect T T T       -- ^ Socket handle, address string, port number.
  | ARecv T T            -- ^ Socket handle, maximum byte count.
  | ASend T T            -- ^ Socket handle, bytes to send.
  | AGetSocketName T     -- ^ Socket handle whose local port is wanted.
  deriving (Show)
|
|
|
|
-- ---------------------------------------------------------------------------
|
|
-- Action tag constants
|
|
-- ---------------------------------------------------------------------------
|
|
|
|
-- Numeric tags carried in the left slot of an action tree, listed in
-- ascending order and grouped by tens.

-- Core (0-9)
tagPure :: Integer
tagPure = 0

tagBind :: Integer
tagBind = 1

-- Console (10-19)
tagPutStr :: Integer
tagPutStr = 10

tagGetLine :: Integer
tagGetLine = 11

tagPutBytes :: Integer
tagPutBytes = 12

-- Filesystem, hashing and time (20-29)
tagReadFile :: Integer
tagReadFile = 20

tagWriteFile :: Integer
tagWriteFile = 21

tagWriteBytes :: Integer
tagWriteBytes = 22

tagListDirectory :: Integer
tagListDirectory = 23

tagRenameFile :: Integer
tagRenameFile = 24

tagCreateDirectory :: Integer
tagCreateDirectory = 25

tagDeleteFile :: Integer
tagDeleteFile = 26

tagFileExists :: Integer
tagFileExists = 27

tagSha256Hex :: Integer
tagSha256Hex = 28

tagCurrentTime :: Integer
tagCurrentTime = 29

-- Environment (30-39)
tagAsk :: Integer
tagAsk = 30

tagLocal :: Integer
tagLocal = 31

-- State (40-49)
tagGet :: Integer
tagGet = 40

tagPut :: Integer
tagPut = 41

-- Tasks (60-69)
tagFork :: Integer
tagFork = 60

tagAwait :: Integer
tagAwait = 61

tagYield :: Integer
tagYield = 62

tagSleep :: Integer
tagSleep = 63

-- Sockets (70-79)
tagSocket :: Integer
tagSocket = 70

tagCloseSocket :: Integer
tagCloseSocket = 71

tagBindSocket :: Integer
tagBindSocket = 72

tagListen :: Integer
tagListen = 73

tagAccept :: Integer
tagAccept = 74

tagConnect :: Integer
tagConnect = 75

tagRecv :: Integer
tagRecv = 76

tagSend :: Integer
tagSend = 77

tagGetSocketName :: Integer
tagGetSocketName = 78
|
|
|
|
-- | Outcome of advancing a 'Machine' by one step.
data Step
  = Halt Runtime T
    -- ^ No frames remain: the final runtime and the final value.
  | Continue Machine
    -- ^ The machine can be stepped again immediately.
  | ForkRequested T Machine
    -- ^ The machine asked to start the given child action.
  | AwaitRequested TaskId Machine
    -- ^ The machine asked to wait for the given task.
  | YieldRequested Machine
    -- ^ The machine asked to give up its turn.
  | SleepRequested Integer Machine
    -- ^ The machine asked to sleep for the given non-negative duration.
  | AsyncAction (IO T) Machine
    -- ^ An IO computation to run on the machine's behalf.
    -- NOTE(review): presumably its result is fed back to this machine by
    -- the scheduler; that code is not in this part of the file.
|
|
|
|
-- | Hand-written because the 'IO' payload of 'AsyncAction' has no 'Show'
-- instance; the 'Runtime' in 'Halt' is elided as @_@.
instance Show Step where
  show step =
    case step of
      Halt _ v             -> "Halt _ " ++ parens (show v)
      Continue m           -> "Continue " ++ parens (show m)
      ForkRequested t m    -> "ForkRequested " ++ parens (show t) ++ " " ++ parens (show m)
      AwaitRequested tid m -> "AwaitRequested " ++ show tid ++ " " ++ parens (show m)
      YieldRequested m     -> "YieldRequested " ++ parens (show m)
      SleepRequested n m   -> "SleepRequested " ++ show n ++ " " ++ parens (show m)
      AsyncAction _ m      -> "AsyncAction <io> " ++ parens (show m)
    where
      parens s = "(" ++ s ++ ")"
|
|
|
|
-- | Decode a tree of the form @Fork tag payload@ into an 'Action'.
--
-- The tag must decode as a number via 'toNumber' and equal one of the
-- @tag*@ constants; the payload is then destructured according to the
-- action's arity.  Every failure is reported as a 'Left' message:
--
-- * the tree is not a 'Fork',
-- * the tag is not a number,
-- * the tag is not a known action tag,
-- * a multi-argument action's payload does not have the expected pair shape.
decodeAction :: T -> Either String Action
decodeAction tree =
  case tree of
    Fork tag payload ->
      either badTag (fromTag payload) (toNumber tag)
    _ ->
      Left ("Invalid action tree: expected pair tag payload, got " ++ show tree)
  where
    badTag :: String -> Either String Action
    badTag err = Left ("Invalid action tag: " ++ err)

    -- Nullary actions ignore whatever sits in the payload slot.
    fromTag :: T -> Integer -> Either String Action
    fromTag payload n
      | n == tagPure            = Right (APure payload)
      | n == tagBind            = binary ABind "Invalid Bind: expected pair action continuation" payload
      | n == tagPutStr          = Right (APutStr payload)
      | n == tagPutBytes        = Right (APutBytes payload)
      | n == tagGetLine         = Right AGetLine
      | n == tagReadFile        = Right (AReadFile payload)
      | n == tagWriteFile       = binary AWriteFile "Invalid WriteFile: expected pair path contents" payload
      | n == tagWriteBytes      = binary AWriteBytes "Invalid WriteBytes: expected pair path contents" payload
      | n == tagListDirectory   = Right (AListDirectory payload)
      | n == tagRenameFile      = binary ARenameFile "Invalid RenameFile: expected pair oldPath newPath" payload
      | n == tagCreateDirectory = Right (ACreateDirectory payload)
      | n == tagDeleteFile      = Right (ADeleteFile payload)
      | n == tagFileExists      = Right (AFileExists payload)
      | n == tagSha256Hex       = Right (ASha256Hex payload)
      | n == tagCurrentTime     = Right ACurrentTime
      | n == tagAsk             = Right AAsk
      | n == tagLocal           = binary ALocal "Invalid Local: expected pair function action" payload
      | n == tagGet             = Right AGet
      | n == tagPut             = Right (APut payload)
      | n == tagFork            = Right (AFork payload)
      | n == tagAwait           = Right (AAwait payload)
      | n == tagYield           = Right AYield
      | n == tagSleep           = Right (ASleep payload)
      | n == tagSocket          = Right ASocket
      | n == tagCloseSocket     = Right (ACloseSocket payload)
      | n == tagBindSocket      = ternary ABindSocket "Invalid BindSocket: expected pair sock (pair addr port)" payload
      | n == tagListen          = binary AListen "Invalid Listen: expected pair sock backlog" payload
      | n == tagAccept          = Right (AAccept payload)
      | n == tagConnect         = ternary AConnect "Invalid Connect: expected pair sock (pair addr port)" payload
      | n == tagRecv            = binary ARecv "Invalid Recv: expected pair sock maxBytes" payload
      | n == tagSend            = binary ASend "Invalid Send: expected pair sock bytes" payload
      | n == tagGetSocketName   = Right (AGetSocketName payload)
      | otherwise               = Left ("Unknown IO action tag: " ++ show n)

    -- Two-argument payloads are a single pair: @Fork a b@.
    binary :: (T -> T -> Action) -> String -> T -> Either String Action
    binary build complaint p =
      case p of
        Fork a b -> Right (build a b)
        _        -> Left complaint

    -- Three-argument payloads are a right-nested pair: @Fork a (Fork b c)@.
    ternary :: (T -> T -> T -> Action) -> String -> T -> Either String Action
    ternary build complaint p =
      case p of
        Fork a (Fork b c) -> Right (build a b c)
        _                 -> Left complaint
|
|
|
|
-- ---------------------------------------------------------------------------
|
|
-- Small-step IO machine
|
|
-- ---------------------------------------------------------------------------
|
|
|
|
-- | Deliver a finished value to whatever is waiting on the frame stack.
--
-- * Empty stack: the machine halts with its runtime and the value.
-- * 'BindFrame' on top: the continuation applied to the value becomes the
--   next action, and the frame is popped.
-- * 'LocalFrame' on top: the saved environment is restored, the frame is
--   popped, and the value is re-wrapped with 'pureAction' so the next step
--   delivers it to the frame below.
--
-- Performs no IO itself; it lives in 'IO' only to match the stepper's type.
finishValue :: Machine -> T -> IO Step
finishValue machine value = pure $! unwound
  where
    -- Forced above so the frame stack is inspected when the action runs,
    -- not later when the resulting Step is examined.
    unwound =
      case machineFrames machine of
        [] ->
          Halt (machineRuntime machine) value
        BindFrame k : below ->
          Continue machine
            { machineCurrent = apply k value
            , machineFrames  = below
            }
        LocalFrame savedEnv : below ->
          Continue machine
            { machineRuntime = (machineRuntime machine) { rtEnv = savedEnv }
            , machineCurrent = pureAction value
            , machineFrames  = below
            }
|
|
|
|
stepMachine :: TVar SocketRegistry -> Machine -> IO Step
|
|
stepMachine sockVar machine =
|
|
case decodeAction (machineCurrent machine) of
|
|
Right action -> dispatch action
|
|
Left _ -> finishValue machine (errResult "invalid action")
|
|
where
|
|
dispatch action = case action of
|
|
APure val ->
|
|
finishValue machine val
|
|
|
|
ABind left k ->
|
|
pure (Continue machine
|
|
{ machineCurrent = left
|
|
, machineFrames = BindFrame k : machineFrames machine
|
|
})
|
|
|
|
APutStr str ->
|
|
case decodeString str "PutStr" of
|
|
Right s ->
|
|
pure (AsyncAction (putStr s >> pure Leaf) machine)
|
|
Left _ ->
|
|
finishValue machine (errResult "invalid string")
|
|
|
|
APutBytes bs ->
|
|
case decodeBytes bs "PutBytes" of
|
|
Right b ->
|
|
pure (AsyncAction (BS.putStr b >> pure Leaf) machine)
|
|
Left _ ->
|
|
finishValue machine (errResult "invalid bytes")
|
|
|
|
AGetLine ->
|
|
pure (AsyncAction (ofString <$> getLine) machine)
|
|
|
|
AReadFile path ->
|
|
case decodeString path "ReadFile" of
|
|
Right p -> do
|
|
mDeny <- checkReadPerm p
|
|
case mDeny of
|
|
Just denied -> finishValue machine denied
|
|
Nothing -> pure (AsyncAction (tryReadFile p) machine)
|
|
Left _ -> finishValue machine (errResult "invalid string")
|
|
|
|
AWriteFile path contents ->
|
|
case decodeString path "WriteFile" of
|
|
Right p ->
|
|
case decodeString contents "WriteFile" of
|
|
Right c -> do
|
|
mDeny <- checkWritePerm p
|
|
case mDeny of
|
|
Just denied -> finishValue machine denied
|
|
Nothing -> pure (AsyncAction (tryWriteFile p c) machine)
|
|
Left _ -> finishValue machine (errResult "invalid string")
|
|
Left _ -> finishValue machine (errResult "invalid string")
|
|
|
|
AWriteBytes path contents ->
|
|
case decodeString path "WriteBytes" of
|
|
Right p ->
|
|
case decodeBytes contents "WriteBytes" of
|
|
Right c -> do
|
|
mDeny <- checkWritePerm p
|
|
case mDeny of
|
|
Just denied -> finishValue machine denied
|
|
Nothing -> pure (AsyncAction (tryWriteFileBytes p c) machine)
|
|
Left _ -> finishValue machine (errResult "invalid bytes")
|
|
Left _ -> finishValue machine (errResult "invalid string")
|
|
|
|
AListDirectory pathTree ->
|
|
case decodeString pathTree "ListDirectory" of
|
|
Right p -> do
|
|
mDeny <- checkReadPerm p
|
|
case mDeny of
|
|
Just denied -> finishValue machine denied
|
|
Nothing -> pure (AsyncAction (tryListDirectory p) machine)
|
|
Left _ -> finishValue machine (errResult "invalid string")
|
|
|
|
ARenameFile oldTree newTree ->
|
|
case decodeString oldTree "RenameFile" of
|
|
Right old ->
|
|
case decodeString newTree "RenameFile" of
|
|
Right new -> do
|
|
mDenyOld <- checkWritePerm old
|
|
mDenyNew <- checkWritePerm new
|
|
case (mDenyOld, mDenyNew) of
|
|
(Just denied, _) -> finishValue machine denied
|
|
(_, Just denied) -> finishValue machine denied
|
|
(Nothing, Nothing) -> pure (AsyncAction (tryRenameFile old new) machine)
|
|
Left _ -> finishValue machine (errResult "invalid string")
|
|
Left _ -> finishValue machine (errResult "invalid string")
|
|
|
|
ACreateDirectory pathTree ->
|
|
case decodeString pathTree "CreateDirectory" of
|
|
Right p -> do
|
|
mDeny <- checkWritePerm p
|
|
case mDeny of
|
|
Just denied -> finishValue machine denied
|
|
Nothing -> pure (AsyncAction (tryCreateDirectory p) machine)
|
|
Left _ -> finishValue machine (errResult "invalid string")
|
|
|
|
ADeleteFile pathTree ->
|
|
case decodeString pathTree "DeleteFile" of
|
|
Right p -> do
|
|
mDeny <- checkWritePerm p
|
|
case mDeny of
|
|
Just denied -> finishValue machine denied
|
|
Nothing -> pure (AsyncAction (tryDeleteFile p) machine)
|
|
Left _ -> finishValue machine (errResult "invalid string")
|
|
|
|
AFileExists pathTree ->
|
|
case decodeString pathTree "FileExists" of
|
|
Right p -> do
|
|
mDeny <- checkReadPerm p
|
|
case mDeny of
|
|
Just denied -> finishValue machine denied
|
|
Nothing -> pure (AsyncAction (tryFileExists p) machine)
|
|
Left _ -> finishValue machine (errResult "invalid string")
|
|
|
|
ASha256Hex bytesTree ->
|
|
case decodeBytes bytesTree "Sha256Hex" of
|
|
Right bs -> pure (AsyncAction (pure $ trySha256Hex bs) machine)
|
|
Left _ -> finishValue machine (errResult "invalid bytes")
|
|
|
|
ACurrentTime ->
|
|
pure (AsyncAction (tryCurrentTime) machine)
|
|
|
|
AAsk ->
|
|
finishValue machine (rtEnv (machineRuntime machine))
|
|
|
|
ALocal f action' ->
|
|
let runtime = machineRuntime machine
|
|
oldEnv = rtEnv runtime
|
|
newEnv = apply f oldEnv
|
|
runtime' = runtime { rtEnv = newEnv }
|
|
in pure (Continue machine
|
|
{ machineRuntime = runtime'
|
|
, machineCurrent = action'
|
|
, machineFrames = LocalFrame oldEnv : machineFrames machine
|
|
})
|
|
|
|
AGet ->
|
|
finishValue machine (rtState (machineRuntime machine))
|
|
|
|
APut newState ->
|
|
let runtime' = (machineRuntime machine) { rtState = newState }
|
|
in finishValue (machine { machineRuntime = runtime' }) Leaf
|
|
|
|
AFork childAction ->
|
|
pure (ForkRequested childAction machine)
|
|
|
|
AAwait handleTree ->
|
|
case decodeTaskHandle handleTree of
|
|
Right taskId ->
|
|
pure (AwaitRequested taskId machine)
|
|
Left _ ->
|
|
finishValue machine invalidAsyncHandleResult
|
|
|
|
AYield ->
|
|
pure (YieldRequested machine)
|
|
|
|
ASleep msTree ->
|
|
case toNumber msTree of
|
|
Right ms | ms >= 0 ->
|
|
pure (SleepRequested ms machine)
|
|
_ ->
|
|
finishValue machine invalidSleepResult
|
|
|
|
ASocket -> do
|
|
result <- try (NS.socket NS.AF_INET NS.Stream NS.defaultProtocol) :: IO (Either SomeException NS.Socket)
|
|
case result of
|
|
Left e ->
|
|
finishValue machine (errResult ("io error: " ++ show e))
|
|
Right sock -> do
|
|
NS.setSocketOption sock NS.ReuseAddr 1
|
|
sid <- atomically $ do
|
|
SocketRegistry m next <- readTVar sockVar
|
|
let sid = SockId next
|
|
writeTVar sockVar (SocketRegistry (Map.insert sid sock m) (next + 1))
|
|
return sid
|
|
finishValue machine (okResult (sockHandle sid))
|
|
|
|
ACloseSocket sockTree ->
|
|
case decodeSockHandle sockTree of
|
|
Left _ -> finishValue machine invalidSocketHandleResult
|
|
Right sid -> do
|
|
mSock <- atomically $ do
|
|
SocketRegistry m next <- readTVar sockVar
|
|
case Map.lookup sid m of
|
|
Nothing -> return Nothing
|
|
Just sock -> do
|
|
writeTVar sockVar (SocketRegistry (Map.delete sid m) next)
|
|
return (Just sock)
|
|
case mSock of
|
|
Nothing -> finishValue machine invalidSocketHandleResult
|
|
Just sock -> do
|
|
NS.close sock
|
|
finishValue machine (okResult Leaf)
|
|
|
|
ABindSocket sockTree addrTree portTree ->
|
|
case decodeSockHandle sockTree of
|
|
Left _ -> finishValue machine invalidSocketHandleResult
|
|
Right sid ->
|
|
case decodeString addrTree "BindSocket" of
|
|
Left _ -> finishValue machine (errResult "invalid address")
|
|
Right addrStr ->
|
|
case toNumber portTree of
|
|
Left _ -> finishValue machine (errResult "invalid port")
|
|
Right port -> do
|
|
mSock <- atomically $ do
|
|
SocketRegistry m _ <- readTVar sockVar
|
|
return (Map.lookup sid m)
|
|
case mSock of
|
|
Nothing -> finishValue machine invalidSocketHandleResult
|
|
Just sock -> do
|
|
result <- try (do
|
|
addrInfo <- NS.getAddrInfo (Just $ NS.defaultHints { NS.addrSocketType = NS.Stream })
|
|
(Just addrStr)
|
|
(Just (show port))
|
|
let serverAddr = head addrInfo
|
|
NS.bind sock (NS.addrAddress serverAddr)
|
|
) :: IO (Either SomeException ())
|
|
case result of
|
|
Left e ->
|
|
finishValue machine (errResult ("io error: " ++ show e))
|
|
Right () ->
|
|
finishValue machine (okResult Leaf)
|
|
|
|
AListen sockTree backlogTree ->
|
|
case decodeSockHandle sockTree of
|
|
Left _ -> finishValue machine invalidSocketHandleResult
|
|
Right sid ->
|
|
case toNumber backlogTree of
|
|
Left _ -> finishValue machine (errResult "invalid backlog")
|
|
Right backlog -> do
|
|
mSock <- atomically $ do
|
|
SocketRegistry m _ <- readTVar sockVar
|
|
return (Map.lookup sid m)
|
|
case mSock of
|
|
Nothing -> finishValue machine invalidSocketHandleResult
|
|
Just sock -> do
|
|
result <- try (NS.listen sock (fromIntegral backlog)) :: IO (Either SomeException ())
|
|
case result of
|
|
Left e ->
|
|
finishValue machine (errResult ("io error: " ++ show e))
|
|
Right () ->
|
|
finishValue machine (okResult Leaf)
|
|
|
|
AAccept listenTree ->
|
|
case decodeSockHandle listenTree of
|
|
Left _ -> finishValue machine invalidSocketHandleResult
|
|
Right listenSid ->
|
|
pure (AsyncAction (do
|
|
mListenSock <- atomically $ do
|
|
SocketRegistry m _ <- readTVar sockVar
|
|
return (Map.lookup listenSid m)
|
|
case mListenSock of
|
|
Nothing -> return (errResult "invalid socket handle")
|
|
Just listenSock -> do
|
|
result <- try (NS.accept listenSock) :: IO (Either SomeException (NS.Socket, NS.SockAddr))
|
|
case result of
|
|
Left e ->
|
|
return (errResult ("io error: " ++ show e))
|
|
Right (clientSock, addr) -> do
|
|
clientSid <- atomically $ do
|
|
SocketRegistry m next <- readTVar sockVar
|
|
let sid = SockId next
|
|
writeTVar sockVar (SocketRegistry (Map.insert sid clientSock m) (next + 1))
|
|
return sid
|
|
let addrStr = case addr of
|
|
NS.SockAddrInet p h ->
|
|
let (a,b,c,d) = NS.hostAddressToTuple h
|
|
in show a ++ "." ++ show b ++ "." ++ show c ++ "." ++ show d ++ ":" ++ show p
|
|
_ -> show addr
|
|
return (okResult (Fork (sockHandle clientSid) (ofString addrStr)))
|
|
) machine)
|
|
|
|
AConnect sockTree addrTree portTree ->
|
|
case decodeSockHandle sockTree of
|
|
Left _ -> finishValue machine invalidSocketHandleResult
|
|
Right sid ->
|
|
case decodeString addrTree "Connect" of
|
|
Left _ -> finishValue machine (errResult "invalid address")
|
|
Right addrStr ->
|
|
case toNumber portTree of
|
|
Left _ -> finishValue machine (errResult "invalid port")
|
|
Right port -> do
|
|
mSock <- atomically $ do
|
|
SocketRegistry m _ <- readTVar sockVar
|
|
return (Map.lookup sid m)
|
|
case mSock of
|
|
Nothing -> finishValue machine invalidSocketHandleResult
|
|
Just sock ->
|
|
pure (AsyncAction (do
|
|
result <- try (do
|
|
addrInfo <- NS.getAddrInfo (Just $ NS.defaultHints { NS.addrSocketType = NS.Stream })
|
|
(Just addrStr)
|
|
(Just (show port))
|
|
let serverAddr = head addrInfo
|
|
NS.connect sock (NS.addrAddress serverAddr)
|
|
) :: IO (Either SomeException ())
|
|
case result of
|
|
Left e ->
|
|
return (errResult ("io error: " ++ show e))
|
|
Right () ->
|
|
return (okResult Leaf)
|
|
) machine)
|
|
|
|
ARecv sockTree maxBytesTree ->
|
|
case decodeSockHandle sockTree of
|
|
Left _ -> finishValue machine invalidSocketHandleResult
|
|
Right sid ->
|
|
case toNumber maxBytesTree of
|
|
Left _ -> finishValue machine (errResult "invalid maxBytes")
|
|
Right maxBytes -> do
|
|
mSock <- atomically $ do
|
|
SocketRegistry m _ <- readTVar sockVar
|
|
return (Map.lookup sid m)
|
|
case mSock of
|
|
Nothing -> finishValue machine invalidSocketHandleResult
|
|
Just sock ->
|
|
pure (AsyncAction (do
|
|
result <- try (NSB.recv sock (fromIntegral maxBytes)) :: IO (Either SomeException BS.ByteString)
|
|
case result of
|
|
Left e ->
|
|
return (errResult ("io error: " ++ show e))
|
|
Right bs ->
|
|
if BS.null bs
|
|
then return (errResult "connection closed")
|
|
else return (okResult (ofBytes bs))
|
|
) machine)
|
|
|
|
AGetSocketName sockTree ->
|
|
case decodeSockHandle sockTree of
|
|
Left _ -> finishValue machine invalidSocketHandleResult
|
|
Right sid -> do
|
|
mSock <- atomically $ do
|
|
SocketRegistry m _ <- readTVar sockVar
|
|
return (Map.lookup sid m)
|
|
case mSock of
|
|
Nothing -> finishValue machine invalidSocketHandleResult
|
|
Just sock -> do
|
|
mPort <- getSocketPort sock
|
|
case mPort of
|
|
Just port -> finishValue machine (okResult (ofNumber port))
|
|
Nothing -> finishValue machine (errResult "io error: could not get socket name")
|
|
|
|
ASend sockTree bytesTree ->
|
|
case decodeSockHandle sockTree of
|
|
Left _ -> finishValue machine invalidSocketHandleResult
|
|
Right sid ->
|
|
case decodeBytes bytesTree "Send" of
|
|
Left _ -> finishValue machine (errResult "invalid bytes")
|
|
Right bs -> do
|
|
mSock <- atomically $ do
|
|
SocketRegistry m _ <- readTVar sockVar
|
|
return (Map.lookup sid m)
|
|
case mSock of
|
|
Nothing -> finishValue machine invalidSocketHandleResult
|
|
Just sock ->
|
|
pure (AsyncAction (do
|
|
result <- try (NSB.send sock bs) :: IO (Either SomeException Int)
|
|
case result of
|
|
Left e ->
|
|
return (errResult ("io error: " ++ show e))
|
|
Right sent ->
|
|
return (okResult (ofNumber (fromIntegral sent)))
|
|
) machine)
|
|
|
|
-- Permission and IO helpers
|
|
checkReadPerm p =
|
|
if allowReadAll (rtPerms (machineRuntime machine))
|
|
then return Nothing
|
|
else do
|
|
mp <- canonicalizeSafe p
|
|
case mp of
|
|
Left _ -> return $ Just policyErrResult
|
|
Right path -> do
|
|
allowed <- pathAllowed path (allowRead (rtPerms (machineRuntime machine)))
|
|
if allowed
|
|
then return Nothing
|
|
else return $ Just policyErrResult
|
|
|
|
checkWritePerm p =
|
|
if allowWriteAll (rtPerms (machineRuntime machine))
|
|
then return Nothing
|
|
else do
|
|
mp <- canonicalizeSafe p
|
|
case mp of
|
|
Left _ -> return $ Just policyErrResult
|
|
Right path -> do
|
|
allowed <- pathAllowed path (allowWrite (rtPerms (machineRuntime machine)))
|
|
if allowed
|
|
then return Nothing
|
|
else return $ Just policyErrResult
|
|
|
|
policyErrResult = errResult "permission denied"
|
|
|
|
canonicalizeSafe :: FilePath -> IO (Either String FilePath)
|
|
canonicalizeSafe p = do
|
|
exists <- doesPathExist p
|
|
if exists
|
|
then do
|
|
result <- try (canonicalizePath p) :: IO (Either SomeException FilePath)
|
|
case result of
|
|
Right canon -> return $ Right canon
|
|
Left _ -> normalizeSyntactic p
|
|
else normalizeSyntactic p
|
|
|
|
normalizeSyntactic :: FilePath -> IO (Either String FilePath)
|
|
normalizeSyntactic p = do
|
|
absPath <- if isRelative p then (</> p) <$> getCurrentDirectory else return p
|
|
let norm = normalise absPath
|
|
dirs = splitDirectories norm
|
|
if ".." `elem` dirs
|
|
then return $ Left "Path contains unresolved parent-directory references"
|
|
else return $ Right norm
|
|
|
|
pathAllowed :: FilePath -> [FilePath] -> IO Bool
|
|
pathAllowed _ [] = return False
|
|
pathAllowed p prefixes = do
|
|
let validPrefixes = filter (not . null) prefixes
|
|
if null validPrefixes
|
|
then return False
|
|
else do
|
|
absPrefixes <- mapM resolvePrefix validPrefixes
|
|
return $ any (isPathPrefixOf p) absPrefixes
|
|
|
|
resolvePrefix :: FilePath -> IO FilePath
|
|
resolvePrefix p = do
|
|
let norm = normalise p
|
|
absPath <- if isRelative norm then (</> norm) <$> getCurrentDirectory else return norm
|
|
exists <- doesPathExist absPath
|
|
if exists
|
|
then do
|
|
result <- try (canonicalizePath absPath) :: IO (Either SomeException FilePath)
|
|
case result of
|
|
Right canon -> return canon
|
|
Left _ -> return absPath
|
|
else return absPath
|
|
|
|
isPathPrefixOf :: FilePath -> FilePath -> Bool
|
|
isPathPrefixOf path prefix =
|
|
let prefix' = addTrailingPathSeparator prefix
|
|
in path == prefix || prefix' `isPrefixOf` path
|
|
|
|
tryReadFile path = do
|
|
result <- try (BS.readFile path) :: IO (Either IOException BS.ByteString)
|
|
case result of
|
|
Right content -> return $ okResult (ofBytes content)
|
|
Left e -> return $ errResult (ioErrorString e)
|
|
|
|
tryWriteFile path contents = do
|
|
result <- try (IO.writeFile path contents) :: IO (Either IOException ())
|
|
case result of
|
|
Right () -> return $ okResult Leaf
|
|
Left e -> return $ errResult (ioErrorString e)
|
|
|
|
tryWriteFileBytes path contents = do
|
|
result <- try (BS.writeFile path contents) :: IO (Either IOException ())
|
|
case result of
|
|
Right () -> return $ okResult Leaf
|
|
Left e -> return $ errResult (ioErrorString e)
|
|
|
|
tryListDirectory path = do
|
|
exists <- doesPathExist path
|
|
if not exists
|
|
then return $ errResult "does not exist"
|
|
else do
|
|
isDir <- doesDirectoryExist path
|
|
if not isDir
|
|
then return $ errResult "not a directory"
|
|
else do
|
|
result <- try (listDirectory path) :: IO (Either IOException [FilePath])
|
|
case result of
|
|
Right entries ->
|
|
let filtered = filter (`notElem` [".", ".."]) entries
|
|
in return $ okResult (ofList (map ofString filtered))
|
|
Left e -> return $ errResult (ioErrorString e)
|
|
|
|
tryRenameFile old new = do
|
|
oldExists <- doesPathExist old
|
|
if not oldExists
|
|
then return $ errResult "does not exist"
|
|
else do
|
|
result <- try (renameFile old new) :: IO (Either IOException ())
|
|
case result of
|
|
Right () -> return $ okResult Leaf
|
|
Left e
|
|
| isDoesNotExistError e -> return $ errResult "does not exist"
|
|
| isPermissionError e -> return $ errResult "permission denied"
|
|
| "cross-device" `isInfixOf` map toLower (show e) || "exdev" `isInfixOf` map toLower (show e) ->
|
|
return $ errResult "cross-device rename"
|
|
| otherwise -> return $ errResult (ioErrorString e)
|
|
|
|
-- Create the directory `path` (a single level; parents are not created).
--
-- Results:
--   * `okResult Leaf`     when `path` is already a directory, or was created
--   * "already exists"    when `path` exists but is not a directory
--   * "not a directory"   when the parent exists but is not a directory
--   * otherwise the classified IOException from `createDirectory`
--
-- When the parent does not exist the creation is still attempted, so the
-- reported error comes from `createDirectory` itself.
tryCreateDirectory path = do
  present <- doesPathExist path
  if present then existingTarget else missingTarget
  where
    existingTarget = do
      directory <- doesDirectoryExist path
      pure $ if directory
        then okResult Leaf
        else errResult "already exists"

    missingTarget = do
      let parent = takeDirectory path
      parentPresent <- doesPathExist parent
      if not parentPresent
        then attemptCreate
        else do
          parentIsDirectory <- doesDirectoryExist parent
          if parentIsDirectory
            then attemptCreate
            else pure (errResult "not a directory")

    attemptCreate =
      describe <$> (try (createDirectory path) :: IO (Either IOException ()))

    describe (Right ()) = okResult Leaf
    describe (Left ioErr)
      | isDoesNotExistError ioErr  = errResult "does not exist"
      | isPermissionError ioErr    = errResult "permission denied"
      | isAlreadyExistsError ioErr = errResult "already exists"
      | otherwise                  = errResult (ioErrorString ioErr)
|
|
|
|
-- Delete the file at `path`.
--
-- Deleting something that is already absent counts as success, both when
-- the up-front existence check fails and when `removeFile` reports a
-- does-not-exist error. Directories are refused with "is a directory".
tryDeleteFile path = do
  present <- doesPathExist path
  if not present
    then pure (okResult Leaf)
    else do
      directory <- doesDirectoryExist path
      if directory
        then pure (errResult "is a directory")
        else summarize <$> (try (removeFile path) :: IO (Either IOException ()))
  where
    summarize (Right ()) = okResult Leaf
    summarize (Left ioErr)
      | isDoesNotExistError ioErr = okResult Leaf
      | isPermissionError ioErr   = errResult "permission denied"
      | otherwise                 = errResult (ioErrorString ioErr)
|
|
|
|
-- Report whether anything exists at `path`.
-- The payload of the `okResult` is `Stem Leaf` when the path exists and
-- `Leaf` when it does not. An IOException becomes "permission denied" or
-- its `ioErrorString`.
tryFileExists path =
  either failure (okResult . flag)
    <$> (try (doesPathExist path) :: IO (Either IOException Bool))
  where
    flag present
      | present   = Stem Leaf
      | otherwise = Leaf

    failure ioErr
      | isPermissionError ioErr = errResult "permission denied"
      | otherwise               = errResult (ioErrorString ioErr)
|
|
|
|
-- SHA-256 of `bs`, rendered as a Base16 (hex) string and wrapped in
-- `okResult`. This is a pure value, not an IO action.
trySha256Hex bs = okResult (ofString hexDigits)
  where
    -- Digest -> raw bytes -> Base16 bytes -> Text -> String.
    hexDigits =
      T.unpack . decodeUtf8 . encode . convert $ (hash bs :: Digest SHA256)
|
|
|
|
-- Current POSIX time, rounded down to whole seconds, as an `okResult`
-- number.
tryCurrentTime = okResult . ofNumber . floor <$> getPOSIXTime
|
|
|
|
-- Decode tree `t` with `toString`. On failure the decoder's own error is
-- discarded and replaced by "Invalid <ctx> string", where `ctx` names what
-- was being decoded.
decodeString t ctx = either (const (Left failure)) Right (toString t)
  where
    failure = "Invalid " ++ ctx ++ " string"
|
|
|
|
-- Decode tree `t` with `toBytes`. On failure the decoder's own error is
-- discarded and replaced by "Invalid <ctx> bytes", where `ctx` names what
-- was being decoded.
decodeBytes t ctx =
  case toBytes t of
    Left _        -> Left (concat ["Invalid ", ctx, " bytes"])
    Right decoded -> Right decoded
|
|
|
|
-- ---------------------------------------------------------------------------
|
|
-- Scheduler
|
|
-- ---------------------------------------------------------------------------
|
|
|
|
-- | Scheduling state of a task that has not halted yet. Every constructor
-- carries the task's 'Machine' so it can be resumed later.
data TaskStatus
  = Runnable Machine          -- ^ May be stepped by the scheduler.
  | BlockedOn TaskId Machine  -- ^ Awaiting the result of the given task.
  | Sleeping UTCTime Machine  -- ^ Parked until the given wake-up time.
  | AsyncWaiting Machine      -- ^ Waiting for a forked IO action to publish its result.
  deriving Show
|
|
|
|
-- | Complete state of the cooperative task scheduler.
data Scheduler = Scheduler
  { schedulerNextTaskId     :: Integer
    -- ^ Number used for the next forked task's 'TaskId'.
  , schedulerRunnable       :: Seq TaskId
    -- ^ Queue of task ids waiting for a turn; taken from the left,
    -- appended on the right.
  , schedulerTasks          :: Map TaskId TaskStatus
    -- ^ Status of every task that has not halted.
  , schedulerWaiters        :: Map TaskId (Seq TaskId)
    -- ^ For an awaited task, the tasks blocked on it.
  , schedulerSleepQueue     :: Map UTCTime (Set TaskId)
    -- ^ Sleeping tasks grouped by wake-up time.
  , schedulerAsyncCompleted :: TVar (Map TaskId T)
    -- ^ Results published by forked IO threads, keyed by the waiting task.
  , schedulerCompleted      :: Map TaskId (T, T)
    -- ^ For halted tasks: the result value and the final runtime state.
  , schedulerSockets        :: TVar SocketRegistry
    -- ^ Shared socket registry handed to the machine stepper.
  , schedulerNextSockId     :: Integer
    -- ^ NOTE(review): initialised to 0 and only shown in this section;
    -- presumably the next socket id - confirm against its users.
  }
|
|
|
|
-- | Hand-written because the two 'TVar' fields have no 'Show' instance;
-- they are rendered as the placeholder @<tvar>@.
instance Show Scheduler where
  show s = concat
    [ "Scheduler { schedulerNextTaskId = ", show (schedulerNextTaskId s)
    , ", schedulerRunnable = ", show (schedulerRunnable s)
    , ", schedulerTasks = ", show (schedulerTasks s)
    , ", schedulerWaiters = ", show (schedulerWaiters s)
    , ", schedulerSleepQueue = ", show (schedulerSleepQueue s)
    , ", schedulerAsyncCompleted = <tvar>"
    , ", schedulerCompleted = ", show (schedulerCompleted s)
    , ", schedulerSockets = <tvar>"
    , ", schedulerNextSockId = ", show (schedulerNextSockId s)
    , " }"
    ]
|
|
|
|
-- | Scheduler holding only the main task.
--
-- The main machine is registered as @TaskId 0@ and is immediately runnable;
-- forked tasks are numbered from 1. The two 'TVar's are supplied by the
-- caller.
initialScheduler :: TVar (Map TaskId T) -> TVar SocketRegistry -> Machine -> Scheduler
initialScheduler asyncVar sockVar mainMachine =
  Scheduler
    { schedulerTasks          = Map.singleton mainId (Runnable mainMachine)
    , schedulerRunnable       = Seq.singleton mainId
    , schedulerNextTaskId     = 1
    , schedulerNextSockId     = 0
    , schedulerWaiters        = Map.empty
    , schedulerSleepQueue     = Map.empty
    , schedulerCompleted      = Map.empty
    , schedulerAsyncCompleted = asyncVar
    , schedulerSockets        = sockVar
    }
  where
    mainId = TaskId 0
|
|
|
|
-- | Runtime of the machine stored in a task status. Every current
-- constructor carries a machine, so the result is always a 'Just'.
runtimeOfStatus :: TaskStatus -> Maybe Runtime
runtimeOfStatus status =
  case status of
    Runnable m     -> runtimeOf m
    BlockedOn _ m  -> runtimeOf m
    Sleeping _ m   -> runtimeOf m
    AsyncWaiting m -> runtimeOf m
  where
    runtimeOf = Just . machineRuntime
|
|
|
|
-- | Wake every task blocked on @targetId@, handing each of them @value@.
--
-- Each waiter that is still 'BlockedOn' becomes 'Runnable' with its current
-- action replaced by @pureAction value@ and is appended to the run queue.
-- Waiters in any other state are left untouched. The waiter list for
-- @targetId@ is removed afterwards.
wakeAwaiters :: TaskId -> T -> Scheduler -> Scheduler
wakeAwaiters targetId value scheduler =
  maybe scheduler release (Map.lookup targetId (schedulerWaiters scheduler))
  where
    release waiters =
      let start = (schedulerTasks scheduler, schedulerRunnable scheduler)
          (wokenTasks, wokenQueue) = Fold.foldl' wake start waiters
      in scheduler
           { schedulerTasks    = wokenTasks
           , schedulerRunnable = wokenQueue
           , schedulerWaiters  = Map.delete targetId (schedulerWaiters scheduler)
           }

    wake acc@(tasks, queue) waiterId =
      case Map.lookup waiterId tasks of
        Just (BlockedOn _ blocked) ->
          let resumed = blocked { machineCurrent = pureAction value }
          in (Map.insert waiterId (Runnable resumed) tasks, queue |> waiterId)
        _ -> acc
|
|
|
|
-- | Move every sleeper whose wake-up time has passed back onto the run
-- queue.
--
-- All sleep-queue buckets with a time @<= now@ are removed. Each task id in
-- them is appended to the run queue (earliest bucket first), and those
-- whose status is 'Sleeping' are switched to 'Runnable'.
wakeDueSleepers :: Scheduler -> IO Scheduler
wakeDueSleepers scheduler = do
  now <- getCurrentTime
  let -- Keys are ordered, so "time <= now" holds for a prefix of the map.
      (due, pending) = Map.spanAntitone (<= now) (schedulerSleepQueue scheduler)
      dueIds = concatMap Set.toList (Map.elems due)
      rouse tasks tid =
        case Map.lookup tid tasks of
          Just (Sleeping _ machine) -> Map.insert tid (Runnable machine) tasks
          _                         -> tasks
  pure scheduler
    { schedulerTasks      = Fold.foldl' rouse (schedulerTasks scheduler) dueIds
    , schedulerRunnable   = Fold.foldl' (|>) (schedulerRunnable scheduler) dueIds
    , schedulerSleepQueue = pending
    }
|
|
|
|
-- | Earliest wake-up time in the sleep queue, or 'Nothing' when no task is
-- sleeping.
nearestSleepTime :: Scheduler -> Maybe UTCTime
nearestSleepTime scheduler =
  fst <$> Map.lookupMin (schedulerSleepQueue scheduler)
|
|
|
|
-- | True when at least one task is in the 'AsyncWaiting' state.
hasAsyncWaiters :: Scheduler -> Bool
hasAsyncWaiters scheduler = any waitingOnAsync (schedulerTasks scheduler)
  where
    waitingOnAsync status =
      case status of
        AsyncWaiting _ -> True
        _              -> False
|
|
|
|
-- | Make @taskId@ runnable again with @value@ as the result of its pending
-- action: the machine's current action becomes @pureAction value@ and the
-- task is appended to the run queue.
resumeCurrentWith :: TaskId -> T -> Machine -> Scheduler -> IO Scheduler
resumeCurrentWith taskId value machine scheduler = pure rescheduled
  where
    resumed = machine { machineCurrent = pureAction value }
    rescheduled = scheduler
      { schedulerTasks    = Map.insert taskId (Runnable resumed) (schedulerTasks scheduler)
      , schedulerRunnable = schedulerRunnable scheduler |> taskId
      }
|
|
|
|
-- | Would blocking @current@ on @target@ close a cycle of awaits?
--
-- Follows the chain of 'BlockedOn' links starting at @target@ and answers
-- 'True' if it reaches @current@. The walk stops with 'False' at the first
-- task that is missing or not blocked.
wouldCycle :: TaskId -> TaskId -> Map TaskId TaskStatus -> Bool
wouldCycle target current tasks = follow target
  where
    follow tid =
      case Map.lookup tid tasks of
        Just (BlockedOn next _)
          | next == current -> True
          | otherwise       -> follow next
        _ -> False
|
|
|
|
-- | Fold the outcome of stepping one task back into the scheduler.
--
-- The first argument is the task that was just stepped; it has already been
-- removed from the run queue by the caller.
handleStep :: TaskId -> Step -> Scheduler -> IO Scheduler

-- The task made progress and wants another turn: store its machine and
-- requeue it at the back.
handleStep taskId (Continue machine) scheduler =
  let tasks' = Map.insert taskId (Runnable machine) (schedulerTasks scheduler)
      queue' = schedulerRunnable scheduler |> taskId
  in pure scheduler { schedulerTasks = tasks', schedulerRunnable = queue' }

-- The task finished: wake its awaiters with the result, drop it from the
-- live tasks and record (result, final runtime state).
handleStep taskId (Halt runtime value) scheduler =
  pure woken
    { schedulerTasks     = Map.delete taskId (schedulerTasks woken)
    , schedulerCompleted =
        Map.insert taskId (value, rtState runtime) (schedulerCompleted woken)
    }
  where
    woken = wakeAwaiters taskId value scheduler

-- Fork: the child starts with the parent's runtime and an empty frame
-- stack; the parent resumes with the child's task handle. The parent is
-- queued ahead of the child.
handleStep parentId (ForkRequested childAction parentMachine) scheduler =
  pure scheduler
    { schedulerNextTaskId = nextNumber + 1
    , schedulerTasks      =
        Map.insert parentId (Runnable resumedParent)
          (Map.insert childId (Runnable child) (schedulerTasks scheduler))
    , schedulerRunnable   = schedulerRunnable scheduler |> parentId |> childId
    }
  where
    nextNumber = schedulerNextTaskId scheduler
    childId = TaskId nextNumber
    resumedParent =
      parentMachine { machineCurrent = pureAction (taskHandle childId) }
    child = Machine
      { machineRuntime = machineRuntime parentMachine
      , machineCurrent = childAction
      , machineFrames  = []
      }

-- Await: resolve immediately when the answer is already known (self-await,
-- completed target, unknown handle, or an await that would close a cycle);
-- otherwise block the caller on the target.
handleStep currentId (AwaitRequested targetId machine) scheduler
  | currentId == targetId = resume selfAwaitResult
  | otherwise =
      case Map.lookup targetId liveTasks of
        Nothing ->
          case Map.lookup targetId (schedulerCompleted scheduler) of
            Just (value, _) -> resume value
            Nothing         -> resume invalidAsyncHandleResult
        Just (BlockedOn _ _)
          | wouldCycle targetId currentId liveTasks ->
              resume (errResult "cyclic await")
        Just _ -> park
  where
    liveTasks = schedulerTasks scheduler
    resume result = resumeCurrentWith currentId result machine scheduler
    park = pure scheduler
      { schedulerTasks   = Map.insert currentId (BlockedOn targetId machine) liveTasks
      , schedulerWaiters = Map.alter enlist targetId (schedulerWaiters scheduler)
      }
    enlist = Just . maybe (Seq.singleton currentId) (|> currentId)

-- Yield: give up the rest of this turn; the task resumes with Leaf.
handleStep taskId (YieldRequested machine) scheduler =
  resumeCurrentWith taskId Leaf machine scheduler

-- Sleep for `ms` milliseconds: park the task under its wake-up time. It
-- resumes with Leaf once woken.
handleStep taskId (SleepRequested ms machine) scheduler = do
  now <- getCurrentTime
  let wakeTime = addUTCTime (fromIntegral ms / 1000) now
      parked = Sleeping wakeTime machine { machineCurrent = pureAction Leaf }
  pure scheduler
    { schedulerTasks      = Map.insert taskId parked (schedulerTasks scheduler)
    , schedulerSleepQueue =
        Map.insertWith Set.union wakeTime (Set.singleton taskId)
          (schedulerSleepQueue scheduler)
    }

-- Async IO: run the action on its own thread and publish the result (or
-- the shown exception as an error result) under this task's id. The task
-- waits in AsyncWaiting until the scheduler collects the result.
handleStep taskId (AsyncAction ioAction machine) scheduler = do
  _ <- forkIO worker
  pure scheduler
    { schedulerTasks = Map.insert taskId (AsyncWaiting machine) (schedulerTasks scheduler)
    }
  where
    worker = do
      outcome <- (Right <$> ioAction) `catch` \(e :: SomeException) -> pure (Left (show e))
      atomically $
        modifyTVar' (schedulerAsyncCompleted scheduler)
          (Map.insert taskId (either errResult id outcome))
|
|
|
|
-- | Decide what to do when the run queue is empty.
--
-- * If some task is sleeping, wait until the nearest wake-up time and then
--   wake the due sleepers. When tasks are also waiting on async IO, the
--   wait ends early as soon as an async result is published, so that
--   result is not held back until the sleeper's deadline.
-- * Otherwise, if tasks are waiting on async IO, block until at least one
--   result has been published.
-- * Otherwise nothing can ever make progress: the main task (@TaskId 0@),
--   if still live, is completed with 'deadlockResult' and its awaiters are
--   woken with that value.
--
-- Async results are only waited for here; they are left in the TVar for
-- the caller's next scheduling step to collect.
handleNoRunnable :: Scheduler -> IO Scheduler
handleNoRunnable scheduler =
  case nearestSleepTime scheduler of
    Just wakeTime -> do
      now <- getCurrentTime
      let micros = max 0 (floor (diffUTCTime wakeTime now * 1000000))
      if hasAsyncWaiters scheduler
        then sleepUnlessAsyncCompletes micros
        else threadDelay micros
      wakeDueSleepers scheduler

    Nothing
      | hasAsyncWaiters scheduler -> do
          -- Block efficiently until at least one async operation completes.
          atomically $ do
            finished <- readTVar asyncDone
            if Map.null finished then retry else return ()
          pure scheduler
      | otherwise -> pure deadlocked
  where
    asyncDone = schedulerAsyncCompleted scheduler

    -- Wait for whichever comes first: the timer or an async completion.
    -- The timer thread is not cancelled on an early return; it finishes on
    -- its own once the delay has elapsed.
    sleepUnlessAsyncCompletes :: Int -> IO ()
    sleepUnlessAsyncCompletes micros = do
      timerFired <- newTVarIO False
      _ <- forkIO (threadDelay micros >> atomically (writeTVar timerFired True))
      atomically $ do
        fired <- readTVar timerFired
        finished <- readTVar asyncDone
        if fired || not (Map.null finished) then return () else retry

    -- Scheduler state after declaring a deadlock on the main task. Left
    -- unchanged when the main task is no longer among the live tasks.
    deadlocked =
      case Map.lookup (TaskId 0) (schedulerTasks scheduler) >>= runtimeOfStatus of
        Nothing -> scheduler
        Just runtime ->
          let woken = wakeAwaiters (TaskId 0) deadlockResult scheduler
          in woken
               { schedulerTasks     = Map.delete (TaskId 0) (schedulerTasks woken)
               , schedulerCompleted =
                   Map.insert (TaskId 0) (deadlockResult, rtState runtime)
                     (schedulerCompleted woken)
               }
|
|
|
|
-- | Perform one scheduling round.
--
-- 1. Atomically take all published async results and resume the tasks that
--    are still 'AsyncWaiting' for them; results for other tasks are
--    dropped.
-- 2. Wake sleepers whose time has come.
-- 3. Pop the head of the run queue and step it if it is 'Runnable'; a
--    queue entry for a task in any other state is simply discarded. With
--    an empty queue, defer to 'handleNoRunnable'.
schedulerStep :: Scheduler -> IO Scheduler
schedulerStep scheduler = do
  finished <- atomically $ do
    published <- readTVar asyncVar
    writeTVar asyncVar Map.empty
    pure published
  resumed <- Fold.foldlM resumeFinished scheduler (Map.toList finished)
  awake <- wakeDueSleepers resumed
  case Seq.viewl (schedulerRunnable awake) of
    EmptyL -> handleNoRunnable awake
    nextId :< remaining ->
      let dequeued = awake { schedulerRunnable = remaining }
      in case Map.lookup nextId (schedulerTasks awake) of
           Just (Runnable machine) ->
             stepMachine (schedulerSockets awake) machine
               >>= \step -> handleStep nextId step dequeued
           _ -> pure dequeued
  where
    asyncVar = schedulerAsyncCompleted scheduler

    resumeFinished current (tid, val) =
      case Map.lookup tid (schedulerTasks current) of
        Just (AsyncWaiting machine) -> resumeCurrentWith tid val machine current
        _                           -> pure current
|
|
|
|
-- | Run scheduling rounds until the main task (@TaskId 0@) has a recorded
-- completion, then return its (result value, final state) pair.
runScheduler :: Scheduler -> IO (T, T)
runScheduler scheduler =
  maybe keepGoing pure (Map.lookup (TaskId 0) (schedulerCompleted scheduler))
  where
    keepGoing = schedulerStep scheduler >>= runScheduler
|
|
|
|
-- ---------------------------------------------------------------------------
|
|
-- Public API
|
|
-- ---------------------------------------------------------------------------
|
|
|
|
-- | Run an IO action tree under the scheduler.
--
-- The action must pass 'checkIOSentinel'; otherwise its error is returned
-- as 'Left' and nothing is run. On success the result is the main task's
-- (result value, final state) pair. Each call starts with a fresh async
-- result map and a fresh socket registry.
runIOWith :: IOPermissions -> T -> T -> T -> IO (Either String (T, T))
runIOWith perms env initialState action =
  either (pure . Left) launch (checkIOSentinel action)
  where
    launch (_, action') = do
      asyncVar <- newTVarIO Map.empty
      sockVar <- newTVarIO (SocketRegistry Map.empty 0)
      let runtime = Runtime
            { rtPerms = perms
            , rtEnv   = env
            , rtState = initialState
            }
          mainMachine = Machine
            { machineRuntime = runtime
            , machineCurrent = action'
            , machineFrames  = []
            }
      Right <$> runScheduler (initialScheduler asyncVar sockVar mainMachine)
|
|
|
|
-- | Like 'runIOWith' with 'Leaf' as the initial state, returning only the
-- result value and discarding the final state.
runIOWithEnv :: IOPermissions -> T -> T -> IO (Either String T)
runIOWithEnv perms env action =
  fmap fst <$> runIOWith perms env Leaf action
|
|
|
|
-- | Like 'runIOWith' with 'Leaf' as both the environment and the initial
-- state, returning only the result value.
runIO :: IOPermissions -> T -> IO (Either String T)
runIO perms action =
  fmap fst <$> runIOWith perms Leaf Leaf action
|