Interaction Trees in Zig and simple benchmarks

This commit is contained in:
2026-05-15 21:41:19 -05:00
parent e3dcf5edd7
commit 8d5e76db1c
17 changed files with 2179 additions and 81 deletions

240
bench/ApplyStats.hs Normal file
View File

@@ -0,0 +1,240 @@
{-# LANGUAGE BangPatterns #-}
-- | Instrumented variants of term application that count how many apply
-- calls are made and how often each (function, argument) pair recurs.
--
-- Two flavours are provided: a version that threads an 'ApplyStats' value
-- through the computation, and a version backed by global mutable counters.
module ApplyStats
  ( ApplyStats(..)
  , emptyApplyStats
  , emptyApplyStatsSampled
  , applyCounted
  , runApplyCounted
    -- 'runApplySampled' was defined but neither exported nor used, which
    -- made it dead code; export it alongside its siblings.
  , runApplySampled
  , runApplySampledWithProgress
  , runApplyGlobalCounted
  , printApplyStats
  ) where
import qualified Control.Exception as E
import Data.IORef
import qualified Data.List as L
import qualified Data.Map.Strict as M
import Data.Ord (comparing)
import Data.Text (Text)
import qualified Data.Text as T
import Debug.Trace (trace)
import System.IO.Unsafe (unsafePerformIO, unsafeDupablePerformIO)

import Research
-- ---------------------------------------------------------------------------
-- Threaded stats (slow but pure)
-- ---------------------------------------------------------------------------
-- | Hash of a term, as returned by 'termHash'.
type Hash = Text
-- | Key for one application: (hash of the function term, hash of the argument term).
type AppKey = (Hash, Hash)
-- | Counters accumulated while evaluating applications.
data ApplyStats = ApplyStats
  { totalApplyCalls :: !Int
    -- ^ Number of apply calls counted so far.
  , uniqueApps :: !(M.Map AppKey Int)
    -- ^ For each sampled (function hash, argument hash) pair, how many
    -- times it was recorded.
  , sampleInterval :: !Int
    -- ^ A call is recorded in 'uniqueApps' when 'sampleCounter' is a
    -- multiple of this value.
  , sampleCounter :: !Int
    -- ^ Running call counter used for the sampling decision.
  , progressEvery :: !Int
    -- ^ Emit a trace message every this many calls; 0 disables tracing.
  }
  deriving Show
-- | Fresh statistics that record every single application
-- (a sampling interval of 1).
emptyApplyStats :: ApplyStats
emptyApplyStats =
  emptyApplyStatsSampled 1
-- | Fresh statistics that record one application out of every @interval@.
--
-- Intervals below 1 are clamped to 1. All counters start at zero and
-- progress tracing starts disabled.
emptyApplyStatsSampled :: Int -> ApplyStats
emptyApplyStatsSampled interval =
  ApplyStats
    { totalApplyCalls = 0
    , sampleCounter = 0
    , sampleInterval = clamped
    , progressEvery = 0
    , uniqueApps = M.empty
    }
  where
    clamped = max 1 interval
-- | Record one application of @f@ to @x@ in the statistics.
--
-- Always increments 'totalApplyCalls' and 'sampleCounter'. When the new
-- counter is a multiple of the sampling interval, the pair of term hashes is
-- also counted in 'uniqueApps'. If 'progressEvery' is non-zero, a trace
-- message is emitted whenever the new total is a multiple of it.
bump :: T -> T -> ApplyStats -> ApplyStats
bump !f !x !st =
  let !counter' = sampleCounter st + 1
      !total' = totalApplyCalls st + 1
      -- The 'ApplyStats' constructor is exported, so a record can arrive
      -- here with a zero (or negative) interval; clamp it the same way
      -- 'emptyApplyStatsSampled' does so that `mod` cannot divide by zero.
      !interval = max 1 (sampleInterval st)
      !stBase = st { totalApplyCalls = total'
                   , sampleCounter = counter'
                   }
      !st' =
        if counter' `mod` interval /= 0
          then stBase
          else
            -- Hashing is only paid for on sampled calls.
            let !hf = termHash f
                !hx = termHash x
                !k = (hf, hx)
                !m = M.insertWith (+) k 1 (uniqueApps st)
            in stBase { uniqueApps = m }
  in case progressEvery st of
       0 -> st'
       n | total' `mod` n == 0 ->
             trace ("apply calls so far: " ++ show total') st'
       _ -> st'
-- | Hash a term by hashing each node from the hashes of its children.
--
-- The whole term is traversed on every call. @nodeHash@ and the
-- @NLeaf@ / @NStem@ / @NFork@ constructors are not defined in this file
-- (presumably they come from "Research").
termHash :: T -> Hash
termHash term = nodeHash $ case term of
  Leaf -> NLeaf
  Stem child -> NStem (termHash child)
  Fork left right -> NFork (termHash left) (termHash right)
-- | Apply @f@ to @x@, threading the statistics through.
--
-- The call itself is recorded first (via 'bump'); the updated statistics
-- are forced and then handed to 'applyStepCounted', which performs the
-- actual reduction step.
applyCounted :: T -> T -> ApplyStats -> (T, ApplyStats)
applyCounted !f !x !st0 =
  applyStepCounted f x $! bump f x st0
-- | Perform one application step, dispatching on the shape of the function
-- term (and, for a fork whose left child is a fork, on the shape of the
-- argument).
--
-- Every nested application goes back through 'applyCounted' so that it is
-- counted; the statistics are threaded left to right through those calls.
applyStepCounted :: T -> T -> ApplyStats -> (T, ApplyStats)
applyStepCounted fun arg st = case fun of
  Leaf -> (Stem arg, st)
  Stem a -> (Fork a arg, st)
  Fork Leaf a -> (a, st)
  Fork (Stem a) b ->
    let (!ac, !st1) = applyCounted a arg st
        (!bc, !st2) = applyCounted b arg st1
    in applyCounted ac bc st2
  Fork (Fork a b) c -> case arg of
    Leaf -> (a, st)
    Stem u -> applyCounted b u st
    Fork u v ->
      let (!cu, !st1) = applyCounted c u st
      in applyCounted cu v st1
-- | Apply @f@ to @x@ starting from 'emptyApplyStats', so every call is
-- recorded. Returns the result together with the final statistics.
runApplyCounted :: T -> T -> (T, ApplyStats)
runApplyCounted !f !x = applyCounted f x initial
  where
    initial = emptyApplyStats
-- | Apply @f@ to @x@ starting from statistics that sample one call out of
-- every @n@ (see 'emptyApplyStatsSampled' for how @n@ is clamped).
runApplySampled :: Int -> T -> T -> (T, ApplyStats)
runApplySampled !n !f !x = applyCounted f x initial
  where
    initial = emptyApplyStatsSampled n
-- | Like a sampled run, but additionally sets 'progressEvery' so that a
-- trace message is emitted periodically while the computation runs.
--
-- The first argument is the sampling interval, the second the progress
-- interval.
runApplySampledWithProgress :: Int -> Int -> T -> T -> (T, ApplyStats)
runApplySampledWithProgress !interval !progress !f !x =
  applyCounted f x initial
  where
    sampled = emptyApplyStatsSampled interval
    initial = sampled { progressEvery = progress }
-- ---------------------------------------------------------------------------
-- Global mutable stats (fast, unsafe, single-threaded only)
-- ---------------------------------------------------------------------------
-- Global mutable cells created with the top-level 'unsafePerformIO' idiom.
-- Each is marked NOINLINE so that the 'newIORef' call is not duplicated at
-- use sites, which would give every use its own, separate cell.
--
-- Running count of calls seen by 'globalBump'; starts at 0.
{-# NOINLINE globalTotalCount #-}
globalTotalCount :: IORef Int
globalTotalCount = unsafePerformIO (newIORef 0)
-- Sampling interval consulted by 'globalBump'; starts at 1.
{-# NOINLINE globalInterval #-}
globalInterval :: IORef Int
globalInterval = unsafePerformIO (newIORef 1)
-- Per-application counts for the sampled calls; starts empty.
{-# NOINLINE globalMap #-}
globalMap :: IORef (M.Map AppKey Int)
globalMap = unsafePerformIO (newIORef M.empty)
-- Progress-trace interval consulted by 'globalBump'; starts at 0 (disabled).
{-# NOINLINE globalProgress #-}
globalProgress :: IORef Int
globalProgress = unsafePerformIO (newIORef 0)
-- | Reset all global counters before a run.
--
-- The call count and the application map are cleared, the sampling interval
-- is stored clamped to at least 1, and the progress interval is stored as
-- given.
resetGlobalStats :: Int -> Int -> IO ()
resetGlobalStats !interval !progress = do
  writeIORef globalTotalCount 0
  writeIORef globalInterval clamped
  writeIORef globalMap M.empty
  writeIORef globalProgress progress
  where
    clamped = max 1 interval
-- | Snapshot the global counters as an 'ApplyStats' record.
--
-- The snapshot carries the interval and progress settings actually in
-- effect instead of zeros, so the record honours the same invariant as one
-- built by 'emptyApplyStatsSampled' (sampling interval of at least 1, as
-- stored by 'resetGlobalStats') and accurately describes the run.
readGlobalStats :: IO ApplyStats
readGlobalStats = do
  total <- readIORef globalTotalCount
  m <- readIORef globalMap
  interval <- readIORef globalInterval
  progress <- readIORef globalProgress
  pure ApplyStats
    { totalApplyCalls = total
    , uniqueApps = m
    , sampleInterval = interval
      -- The global path makes its sampling decision on the total call
      -- count, so that count doubles as the sample counter.
    , sampleCounter = total
    , progressEvery = progress
    }
-- | Record one application of @f@ to @x@ in the global counters, as a side
-- effect hidden behind a pure-looking @()@ result.
--
-- Increments the global call count, optionally emits a progress trace, and
-- on sampled calls counts the (function hash, argument hash) pair in the
-- global map.
--
-- NOTE(review): this combines INLINE with 'unsafeDupablePerformIO'; the
-- counts are only meaningful if GHC neither shares nor drops calls to this
-- function — confirm against the optimisation flags used for the benchmark.
{-# INLINE globalBump #-}
globalBump :: T -> T -> ()
globalBump !f !x = unsafeDupablePerformIO $ do
  !seen <- (+ 1) <$> readIORef globalTotalCount
  writeIORef globalTotalCount seen
  !every <- readIORef globalInterval
  !progress <- readIORef globalProgress
  -- Forcing this binding is what triggers the trace output.
  let !_ =
        if progress > 0 && seen `mod` progress == 0
          then trace ("apply calls so far: " ++ show seen) ()
          else ()
  if seen `mod` every == 0
    then do
      -- Hashing is only paid for on sampled calls.
      let !hf = termHash f
          !hx = termHash x
      !counts <- readIORef globalMap
      writeIORef globalMap (M.insertWith (+) (hf, hx) 1 counts)
    else pure ()
-- | Apply @f@ to @x@, recording the call in the global counters first.
--
-- The @()@ returned by 'globalBump' is forced before the reduction step
-- runs, which is what makes the counting side effect happen.
applyGlobalCounted :: T -> T -> T
applyGlobalCounted !f !x =
  globalBump f x `seq` applyGlobalStep f x
-- | Perform one application step, dispatching on the shape of the function
-- term (and, for a fork whose left child is a fork, on the shape of the
-- argument).
--
-- Every nested application goes back through 'applyGlobalCounted' so that
-- it is recorded in the global counters.
applyGlobalStep :: T -> T -> T
applyGlobalStep fun arg = case fun of
  Leaf -> Stem arg
  Stem a -> Fork a arg
  Fork Leaf a -> a
  Fork (Stem a) b ->
    applyGlobalCounted (applyGlobalCounted a arg) (applyGlobalCounted b arg)
  Fork (Fork a b) c -> case arg of
    Leaf -> a
    Stem u -> applyGlobalCounted b u
    Fork u v -> applyGlobalCounted (applyGlobalCounted c u) v
-- | Apply @f@ to @x@ using the global counters and return the result
-- together with a snapshot of the statistics.
--
-- The first argument is the sampling interval, the second the progress
-- interval; both are installed by 'resetGlobalStats' before the run.
--
-- The application is forced with 'E.evaluate' rather than a bang-@let@:
-- 'E.evaluate' is an IO action, so the evaluation (and therefore all the
-- counting side effects) is sequenced after the reset and before the
-- snapshot, instead of relying on how @seq@ happens to be scheduled.
-- The result is evaluated to weak head normal form, as before.
--
-- NOTE(review): the counted computation is still a pure expression from
-- GHC's point of view; whether repeated runs on the same inputs are
-- re-evaluated rather than shared should be confirmed at the call sites.
runApplyGlobalCounted :: Int -> Int -> T -> T -> IO (T, ApplyStats)
runApplyGlobalCounted !interval !progress !f !x = do
  resetGlobalStats interval progress
  result <- E.evaluate (applyGlobalCounted f x)
  !stats <- readGlobalStats
  pure (result, stats)
-- ---------------------------------------------------------------------------
-- Printing
-- ---------------------------------------------------------------------------
-- | Print a human-readable summary of the statistics to standard output.
--
-- Reports the total call count, the number of distinct recorded
-- applications, their ratio (0 when nothing was recorded), the number of
-- applications recorded more than once, and up to 20 of those with the
-- highest counts. Hashes are shortened to their first 12 characters.
--
-- NOTE(review): the ratio divides the total call count by the number of
-- recorded applications, so with a sampling interval above 1 it is not a
-- like-for-like comparison — confirm this is the intended figure.
printApplyStats :: ApplyStats -> IO ()
printApplyStats st = do
  putStrLn $ "total apply calls: " ++ show total
  putStrLn $ "unique application patterns: " ++ show uniq
  putStrLn $ "duplication ratio total/unique: " ++ show ratio
  putStrLn $ "repeated application patterns: " ++ show (length repeated)
  putStrLn "top repeated application counts:"
  mapM_ (putStrLn . describe) (take 20 repeated)
  where
    total = totalApplyCalls st
    apps = uniqueApps st
    uniq = M.size apps

    ratio :: Double
    ratio
      | uniq == 0 = 0
      | otherwise = fromIntegral total / fromIntegral uniq

    -- Ascending stable sort followed by a reverse: highest counts first.
    byCountDesc = reverse (L.sortBy (comparing snd) (M.toList apps))
    repeated = filter (\(_, n) -> n > 1) byCountDesc

    abbreviate = T.unpack . T.take 12
    describe ((hf, hx), n) =
      " " ++ show n ++ "x apply " ++ abbreviate hf ++ " " ++ abbreviate hx