Files
tricu/ext/zig/tests/python_ffi_test.py

252 lines
8.8 KiB
Python

#!/usr/bin/env python3
"""Python FFI tests for the Arboricx C ABI.
Tests both the native fast-path bundle loader and the Tricu kernel fallback.
"""
import ctypes
import os
import sys
import time
# Directory holding this script, and the directory one level above it.
SCRIPT_DIR = os.path.split(os.path.abspath(__file__))[0]
ZIG_DIR = os.path.dirname(SCRIPT_DIR)

# The shared library location can be overridden with ARBORICX_LIB; otherwise
# fall back to zig-out/lib/libarboricx.so underneath ZIG_DIR.
lib_path = os.environ.get("ARBORICX_LIB")
if lib_path is None:
    lib_path = os.path.join(ZIG_DIR, "zig-out", "lib", "libarboricx.so")

# Every FFI call in this file goes through this handle.
lib = ctypes.CDLL(lib_path)
# ctypes prototypes for the C ABI functions used by this script.
# Short aliases for the C types that recur in the signatures below.
_VOID_P = ctypes.c_void_p
_CHAR_P = ctypes.c_char_p
_INT = ctypes.c_int
_SIZE = ctypes.c_size_t
_U32 = ctypes.c_uint32
_U64 = ctypes.c_uint64
_INT_P = ctypes.POINTER(ctypes.c_int)
_SIZE_P = ctypes.POINTER(ctypes.c_size_t)
_U8_P = ctypes.POINTER(ctypes.c_uint8)
_U8_PP = ctypes.POINTER(_U8_P)
_U32_P = ctypes.POINTER(ctypes.c_uint32)
_U64_P = ctypes.POINTER(ctypes.c_uint64)

# --- Lifecycle ---
# Only restype is declared for init and only argtypes for free.
lib.arboricx_init.restype = _VOID_P
lib.arboricx_free.argtypes = [_VOID_P]

# --- Tree construction ---
# Each constructor takes the context pointer first and returns a uint32.
lib.arb_leaf.restype = _U32
lib.arb_leaf.argtypes = [_VOID_P]
lib.arb_stem.restype = _U32
lib.arb_stem.argtypes = [_VOID_P, _U32]
lib.arb_fork.restype = _U32
lib.arb_fork.argtypes = [_VOID_P, _U32, _U32]
lib.arb_app.restype = _U32
lib.arb_app.argtypes = [_VOID_P, _U32, _U32]

# --- Reduction ---
lib.arb_reduce.restype = _U32
lib.arb_reduce.argtypes = [_VOID_P, _U32, _U64]

# --- Codecs ---
# Host value -> tree.
lib.arb_of_number.restype = _U32
lib.arb_of_number.argtypes = [_VOID_P, _U64]
lib.arb_of_string.restype = _U32
lib.arb_of_string.argtypes = [_VOID_P, _CHAR_P]
lib.arb_of_bytes.restype = _U32
lib.arb_of_bytes.argtypes = [_VOID_P, _U8_P, _SIZE]
lib.arb_of_list.restype = _U32
lib.arb_of_list.argtypes = [_VOID_P, _U32_P, _SIZE]
# Tree -> host value; these return an int status and write through out-pointers.
lib.arb_to_number.restype = _INT
lib.arb_to_number.argtypes = [_VOID_P, _U32, _U64_P]
lib.arb_to_string.restype = _INT
lib.arb_to_string.argtypes = [_VOID_P, _U32, _U8_PP, _SIZE_P]
lib.arb_to_bool.restype = _INT
lib.arb_to_bool.argtypes = [_VOID_P, _U32, _INT_P]
# Releases a buffer handed out by the library (see to_string below).
lib.arboricx_free_buf.argtypes = [_VOID_P, _U8_P, _SIZE]

# --- Result unwrapping ---
lib.arb_unwrap_result.restype = _INT
lib.arb_unwrap_result.argtypes = [_VOID_P, _U32, _INT_P, _U32_P, _U32_P]
lib.arb_unwrap_host_value.restype = _INT
lib.arb_unwrap_host_value.argtypes = [_VOID_P, _U32, _U64_P, _U32_P]

# --- Kernel ---
lib.arb_kernel_root.restype = _U32
lib.arb_kernel_root.argtypes = [_VOID_P]

# --- Native bundle loading ---
lib.arb_load_bundle.restype = _U32
lib.arb_load_bundle.argtypes = [_VOID_P, _U8_P, _SIZE, _CHAR_P]
lib.arb_load_bundle_default.restype = _U32
lib.arb_load_bundle_default.argtypes = [_VOID_P, _U8_P, _SIZE]
# Create the library context that is passed as the first argument of every
# other FFI call in this script.
ctx = lib.arboricx_init()
# With restype c_void_p, ctypes returns None for a NULL pointer. Stop here
# rather than report success and pass a NULL context to the library.
if not ctx:
    raise RuntimeError("arboricx_init failed")
print("ctx init ok")

# Bundle fixtures are read from test/fixtures, two directory levels above ZIG_DIR.
fixtures = os.path.join(ZIG_DIR, "..", "..", "test", "fixtures")
def read_bundle(name):
    """Return the full binary contents of fixture file *name*.

    *name* is joined onto the module-level ``fixtures`` directory.
    """
    bundle_path = os.path.join(fixtures, name)
    with open(bundle_path, "rb") as handle:
        contents = handle.read()
    return contents
def c_bytes(py_bytes):
    """Copy *py_bytes* into a new ``ctypes.c_uint8`` array of the same length.

    The returned array owns its own storage, so later changes to a mutable
    input do not affect it. The caller must keep the array referenced for as
    long as native code uses it.
    """
    array_type = ctypes.c_uint8 * len(py_bytes)
    if isinstance(py_bytes, (bytes, bytearray)):
        # One bulk copy instead of unpacking every byte into a separate
        # Python int argument, which is slow for large bundles.
        return array_type.from_buffer_copy(py_bytes)
    # Other sequences of ints (e.g. a list) keep the element-wise path.
    return array_type(*py_bytes)
def to_string(ctx, root):
    """Decode the tree *root* into a Python ``str``.

    Calls ``arb_to_string``, which writes a byte buffer pointer and its length
    through out-parameters. The bytes are copied out, the native buffer is
    handed back via ``arboricx_free_buf``, and the copy is decoded as UTF-8.

    Raises:
        RuntimeError: if ``arb_to_string`` returns 0.
        UnicodeDecodeError: if the returned bytes are not valid UTF-8.
    """
    buf_ptr = ctypes.POINTER(ctypes.c_uint8)()
    buf_len = ctypes.c_size_t()
    if not lib.arb_to_string(ctx, root, ctypes.byref(buf_ptr), ctypes.byref(buf_len)):
        raise RuntimeError("to_string failed")
    try:
        # Single bulk copy out of the native buffer instead of indexing the
        # pointer one byte at a time from Python.
        raw = ctypes.string_at(buf_ptr, buf_len.value)
    finally:
        # Always return the buffer to the library, even if the copy fails.
        lib.arboricx_free_buf(ctx, buf_ptr, buf_len.value)
    return raw.decode("utf-8")
def to_number(ctx, root):
    """Decode the tree *root* into a Python int via ``arb_to_number``.

    Raises:
        RuntimeError: if ``arb_to_number`` returns 0.
    """
    decoded = ctypes.c_uint64()
    status = lib.arb_to_number(ctx, root, ctypes.byref(decoded))
    if status:
        return decoded.value
    raise RuntimeError("to_number failed")
def to_bool(ctx, root):
    """Decode the tree *root* into a Python bool via ``arb_to_bool``.

    Raises:
        RuntimeError: if ``arb_to_bool`` returns 0.
    """
    flag = ctypes.c_int()
    status = lib.arb_to_bool(ctx, root, ctypes.byref(flag))
    if status:
        # Any non-zero value written by the library counts as True.
        return bool(flag.value)
    raise RuntimeError("to_bool failed")
def kernel_run(bundle_bytes, args):
    """Evaluate a bundle through the Tricu kernel and return its string output.

    This is the slow path (roughly 3 s for the append fixture). The kernel root
    is applied, in order, to the number 1, the bundle encoded as bytes, and the
    arguments encoded as strings and chained together with forks. The reduced
    result is unwrapped twice and the final payload is decoded as a string.

    Raises:
        RuntimeError: if ``arb_unwrap_result``, ``arb_unwrap_host_value`` or
            the string decoding reports failure.
    """
    raw = c_bytes(bundle_bytes)
    bundle_node = lib.arb_of_bytes(ctx, raw, len(bundle_bytes))
    tag_node = lib.arb_of_number(ctx, 1)

    encoded_args = [lib.arb_of_string(ctx, text.encode("utf-8")) for text in args]

    # Build the chain back-to-front: start from a leaf and wrap it with one
    # fork per argument, so the first argument ends up in the outermost fork.
    arg_chain = lib.arb_leaf(ctx)
    for node in encoded_args[::-1]:
        arg_chain = lib.arb_fork(ctx, node, arg_chain)

    # kernel tag bundle args, applied one argument at a time.
    expr = lib.arb_app(ctx, lib.arb_kernel_root(ctx), tag_node)
    expr = lib.arb_app(ctx, expr, bundle_node)
    expr = lib.arb_app(ctx, expr, arg_chain)
    reduced = lib.arb_reduce(ctx, expr, 1_000_000_000)

    # NOTE(review): ok_flag and remainder are filled in by the library but never
    # inspected here; confirm whether a zero ok_flag should be treated as an error.
    ok_flag = ctypes.c_int()
    inner = ctypes.c_uint32()
    remainder = ctypes.c_uint32()
    unwrapped = lib.arb_unwrap_result(
        ctx, reduced, ctypes.byref(ok_flag), ctypes.byref(inner), ctypes.byref(remainder)
    )
    if not unwrapped:
        raise RuntimeError("unwrap_result failed")

    # The host-value tag is required by the call signature but is not used.
    host_tag = ctypes.c_uint64()
    host_payload = ctypes.c_uint32()
    unwrapped = lib.arb_unwrap_host_value(
        ctx, inner.value, ctypes.byref(host_tag), ctypes.byref(host_payload)
    )
    if not unwrapped:
        raise RuntimeError("unwrap_host_value failed")

    return to_string(ctx, host_payload.value)
def native_run_default(bundle_bytes, args):
    """Evaluate a bundle through the native loader and return its string output.

    This is the fast path (roughly 0.01 s). The term returned by
    ``arb_load_bundle_default`` is applied to each argument in turn (encoded as
    strings), reduced, and decoded as a string.

    Raises:
        RuntimeError: if the loader returns 0 or the string decoding fails.
    """
    raw = c_bytes(bundle_bytes)
    expr = lib.arb_load_bundle_default(ctx, raw, len(bundle_bytes))
    if expr == 0:
        raise RuntimeError("load_bundle_default failed")

    # Left-nested application: ((term arg0) arg1) ...
    for text in args:
        encoded = lib.arb_of_string(ctx, text.encode("utf-8"))
        expr = lib.arb_app(ctx, expr, encoded)

    reduced = lib.arb_reduce(ctx, expr, 1_000_000_000)
    return to_string(ctx, reduced)
def native_run_named(bundle_bytes, name, args):
    """Evaluate the export *name* of a bundle through the native loader.

    Same flow as the default native path, except the term comes from
    ``arb_load_bundle`` with *name* passed as a UTF-8 encoded C string.

    Raises:
        RuntimeError: if the loader returns 0 or the string decoding fails.
    """
    raw = c_bytes(bundle_bytes)
    export_name = name.encode("utf-8")
    expr = lib.arb_load_bundle(ctx, raw, len(bundle_bytes), export_name)
    if expr == 0:
        raise RuntimeError(f"load_bundle({name!r}) failed")

    # Left-nested application: ((term arg0) arg1) ...
    for text in args:
        encoded = lib.arb_of_string(ctx, text.encode("utf-8"))
        expr = lib.arb_app(ctx, expr, encoded)

    reduced = lib.arb_reduce(ctx, expr, 1_000_000_000)
    return to_string(ctx, reduced)
# ============================================================================
# Tests
# ============================================================================
# Overall verdict for the run; check() sets it to False on any mismatch and
# nothing sets it back to True.
all_ok = True


def check(label, got, want):
    """Compare *got* with *want* and print a PASS or FAIL line for *label*.

    A mismatch also clears the module-level ``all_ok`` flag. Returns None.
    """
    global all_ok
    mismatch = got != want
    if not mismatch:
        print(f"PASS {label}: {got!r}")
        return
    print(f"FAIL {label}: got {got!r}, want {want!r}")
    all_ok = False
def _timed_check(label, want, run, *run_args):
    """Call ``run(*run_args)``, check the result against *want*, print the elapsed time."""
    # perf_counter is monotonic and high resolution; time.time() can jump when
    # the wall clock is adjusted, which distorts short interval measurements.
    started = time.perf_counter()
    got = run(*run_args)
    elapsed_ms = (time.perf_counter() - started) * 1000
    check(label, got, want)
    print(f" time: {elapsed_ms:.1f} ms")


try:
    # Test 1: id via kernel
    print("\n--- Test 1: id (kernel path) ---")
    bundle = read_bundle("id.arboricx")
    _timed_check("id kernel", "hello", kernel_run, bundle, ["hello"])

    # Test 2: id via native (reuses the id bundle read above)
    print("\n--- Test 2: id (native path) ---")
    _timed_check("id native", "hello", native_run_default, bundle, ["hello"])

    # Test 3: append via kernel
    print("\n--- Test 3: append (kernel path) ---")
    bundle = read_bundle("append.arboricx")
    _timed_check("append kernel", "Hello, world!", kernel_run, bundle, ["Hello, ", "world!"])

    # Test 4: append via native (reuses the append bundle read above)
    print("\n--- Test 4: append (native path) ---")
    _timed_check("append native", "Hello, world!", native_run_default, bundle, ["Hello, ", "world!"])

    # Test 5: append via native named export
    print("\n--- Test 5: append via named export 'root' ---")
    _timed_check("append named", "Hello, world!", native_run_named, bundle, "root", ["Hello, ", "world!"])

    # Test 6: true / false via native
    print("\n--- Test 6: true / false (native path) ---")
    for name, expected in [("true.arboricx", True), ("false.arboricx", False)]:
        bundle = read_bundle(name)
        buf = c_bytes(bundle)
        term = lib.arb_load_bundle_default(ctx, buf, len(bundle))
        # Same failure convention native_run_default checks: 0 means the load failed.
        if term == 0:
            raise RuntimeError(f"load_bundle_default({name!r}) failed")
        result = lib.arb_reduce(ctx, term, 1_000_000_000)
        check(f"{name} bool", to_bool(ctx, result), expected)

    # Test 7: number roundtrip
    print("\n--- Test 7: number roundtrip ---")
    num_tree = lib.arb_of_number(ctx, 42)
    check("number 42", to_number(ctx, num_tree), 42)

    # Test 8: string roundtrip
    print("\n--- Test 8: string roundtrip ---")
    str_tree = lib.arb_of_string(ctx, b"hello")
    check("string hello", to_string(ctx, str_tree), "hello")
finally:
    # Release the context even when a test raised part-way through.
    lib.arboricx_free(ctx)

# Exit status mirrors the accumulated verdict: 0 on success, 1 on any failure.
if all_ok:
    print("\nAll tests passed!")
    sys.exit(0)
else:
    print("\nSome tests failed!")
    sys.exit(1)