#!/usr/bin/env python3
"""Python FFI tests for the Arboricx C ABI.

Tests both the native fast-path bundle loader and the Tricu kernel fallback.

The shared library is taken from the ARBORICX_LIB environment variable when
set, otherwise from the Zig build output (zig-out/lib/libarboricx.so).
"""

import ctypes
import os
import sys
import time

SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
ZIG_DIR = os.path.dirname(SCRIPT_DIR)

lib_path = os.environ.get(
    "ARBORICX_LIB",
    os.path.join(ZIG_DIR, "zig-out", "lib", "libarboricx.so"),
)
lib = ctypes.CDLL(lib_path)

# Third argument passed to every arb_reduce call below.  Presumably a
# reduction step/fuel budget -- NOTE(review): confirm against the C ABI.
REDUCE_LIMIT = 1_000_000_000

# --- Lifecycle ---
lib.arboricx_init.argtypes = []
lib.arboricx_init.restype = ctypes.c_void_p
lib.arboricx_free.argtypes = [ctypes.c_void_p]
lib.arboricx_free.restype = None

# --- Tree construction ---
# Tree nodes are passed around as uint32 handles.
lib.arb_leaf.argtypes = [ctypes.c_void_p]
lib.arb_leaf.restype = ctypes.c_uint32
lib.arb_stem.argtypes = [ctypes.c_void_p, ctypes.c_uint32]
lib.arb_stem.restype = ctypes.c_uint32
lib.arb_fork.argtypes = [ctypes.c_void_p, ctypes.c_uint32, ctypes.c_uint32]
lib.arb_fork.restype = ctypes.c_uint32
lib.arb_app.argtypes = [ctypes.c_void_p, ctypes.c_uint32, ctypes.c_uint32]
lib.arb_app.restype = ctypes.c_uint32

# --- Reduction ---
lib.arb_reduce.argtypes = [ctypes.c_void_p, ctypes.c_uint32, ctypes.c_uint64]
lib.arb_reduce.restype = ctypes.c_uint32

# --- Codecs ---
lib.arb_of_number.argtypes = [ctypes.c_void_p, ctypes.c_uint64]
lib.arb_of_number.restype = ctypes.c_uint32
lib.arb_of_string.argtypes = [ctypes.c_void_p, ctypes.c_char_p]
lib.arb_of_string.restype = ctypes.c_uint32
lib.arb_of_bytes.argtypes = [
    ctypes.c_void_p,
    ctypes.POINTER(ctypes.c_uint8),
    ctypes.c_size_t,
]
lib.arb_of_bytes.restype = ctypes.c_uint32
lib.arb_of_list.argtypes = [
    ctypes.c_void_p,
    ctypes.POINTER(ctypes.c_uint32),
    ctypes.c_size_t,
]
lib.arb_of_list.restype = ctypes.c_uint32

# Decoders return nonzero on success and write the decoded value through
# their trailing out-pointer(s).
lib.arb_to_number.argtypes = [
    ctypes.c_void_p,
    ctypes.c_uint32,
    ctypes.POINTER(ctypes.c_uint64),
]
lib.arb_to_number.restype = ctypes.c_int
lib.arb_to_string.argtypes = [
    ctypes.c_void_p,
    ctypes.c_uint32,
    ctypes.POINTER(ctypes.POINTER(ctypes.c_uint8)),
    ctypes.POINTER(ctypes.c_size_t),
]
lib.arb_to_string.restype = ctypes.c_int
lib.arb_to_bool.argtypes = [
    ctypes.c_void_p,
    ctypes.c_uint32,
    ctypes.POINTER(ctypes.c_int),
]
lib.arb_to_bool.restype = ctypes.c_int
# Releases the buffer handed out by arb_to_string.
lib.arboricx_free_buf.argtypes = [
    ctypes.c_void_p,
    ctypes.POINTER(ctypes.c_uint8),
    ctypes.c_size_t,
]
lib.arboricx_free_buf.restype = None

# --- Result unwrapping ---
lib.arb_unwrap_result.argtypes = [
    ctypes.c_void_p,
    ctypes.c_uint32,
    ctypes.POINTER(ctypes.c_int),
    ctypes.POINTER(ctypes.c_uint32),
    ctypes.POINTER(ctypes.c_uint32),
]
lib.arb_unwrap_result.restype = ctypes.c_int
lib.arb_unwrap_host_value.argtypes = [
    ctypes.c_void_p,
    ctypes.c_uint32,
    ctypes.POINTER(ctypes.c_uint64),
    ctypes.POINTER(ctypes.c_uint32),
]
lib.arb_unwrap_host_value.restype = ctypes.c_int

# --- Kernel ---
lib.arb_kernel_root.argtypes = [ctypes.c_void_p]
lib.arb_kernel_root.restype = ctypes.c_uint32

# --- Native bundle loading ---
# Both loaders return 0 on failure (checked by the callers below).
lib.arb_load_bundle.argtypes = [
    ctypes.c_void_p,
    ctypes.POINTER(ctypes.c_uint8),
    ctypes.c_size_t,
    ctypes.c_char_p,
]
lib.arb_load_bundle.restype = ctypes.c_uint32
lib.arb_load_bundle_default.argtypes = [
    ctypes.c_void_p,
    ctypes.POINTER(ctypes.c_uint8),
    ctypes.c_size_t,
]
lib.arb_load_bundle_default.restype = ctypes.c_uint32

ctx = lib.arboricx_init()
# With restype c_void_p a NULL context comes back as None; fail loudly
# instead of passing it into every subsequent call.
if not ctx:
    raise RuntimeError("arboricx_init failed")
print("ctx init ok")

fixtures = os.path.join(ZIG_DIR, "..", "..", "test", "fixtures")


def read_bundle(name):
    """Return the raw bytes of fixture bundle *name*."""
    path = os.path.join(fixtures, name)
    with open(path, "rb") as f:
        return f.read()


def c_bytes(py_bytes):
    """Copy *py_bytes* into a fresh ctypes ``uint8`` array.

    ``from_buffer_copy`` performs a single memcpy instead of unpacking every
    byte into a Python-level argument list, which matters for large bundles.
    """
    return (ctypes.c_uint8 * len(py_bytes)).from_buffer_copy(py_bytes)


def to_string(ctx, root):
    """Decode tree *root* as a UTF-8 string.

    The bytes come back in a library buffer that must be released with
    ``arboricx_free_buf``; it is released even if copying it out fails.

    Raises:
        RuntimeError: if the library cannot decode *root* as a string.
        UnicodeDecodeError: if the decoded bytes are not valid UTF-8.
    """
    ptr = ctypes.POINTER(ctypes.c_uint8)()
    length = ctypes.c_size_t()
    if not lib.arb_to_string(ctx, root, ctypes.byref(ptr), ctypes.byref(length)):
        raise RuntimeError("to_string failed")
    try:
        # Guard the empty case so a possibly-NULL pointer is never read.
        raw = ctypes.string_at(ptr, length.value) if length.value else b""
    finally:
        lib.arboricx_free_buf(ctx, ptr, length.value)
    return raw.decode("utf-8")


def to_number(ctx, root):
    """Decode tree *root* as an unsigned 64-bit number.

    Raises:
        RuntimeError: if the library cannot decode *root* as a number.
    """
    out = ctypes.c_uint64()
    if not lib.arb_to_number(ctx, root, ctypes.byref(out)):
        raise RuntimeError("to_number failed")
    return out.value


def to_bool(ctx, root):
    """Decode tree *root* as a boolean.

    Raises:
        RuntimeError: if the library cannot decode *root* as a boolean.
    """
    out = ctypes.c_int()
    if not lib.arb_to_bool(ctx, root, ctypes.byref(out)):
        raise RuntimeError("to_bool failed")
    return bool(out.value)


def kernel_run(bundle_bytes, args):
    """Run via the Tricu kernel interpreter (slow, ~3s for append).

    Builds ``kernel_root 1 <bundle as bytes> <args as list>``, reduces it,
    unwraps the result and host-value envelopes, and decodes the payload as
    a string.

    Raises:
        RuntimeError: if the result is unsuccessful or any unwrap/decode
            step fails.
    """
    # Keep `buf` referenced for the whole call, as the original did, in case
    # the library retains the pointer.  NOTE(review): may be unnecessary.
    buf = c_bytes(bundle_bytes)
    bundle_tree = lib.arb_of_bytes(ctx, buf, len(bundle_bytes))
    tag = lib.arb_of_number(ctx, 1)

    # Encode the arguments as a fork-linked list terminated by a leaf.
    arg_items = [lib.arb_of_string(ctx, a.encode("utf-8")) for a in args]
    arg_list = lib.arb_leaf(ctx)
    for item in reversed(arg_items):
        arg_list = lib.arb_fork(ctx, item, arg_list)

    app = lib.arb_app(ctx, lib.arb_kernel_root(ctx), tag)
    app = lib.arb_app(ctx, app, bundle_tree)
    app = lib.arb_app(ctx, app, arg_list)
    result = lib.arb_reduce(ctx, app, REDUCE_LIMIT)

    ok = ctypes.c_int()
    value = ctypes.c_uint32()
    rest = ctypes.c_uint32()
    if not lib.arb_unwrap_result(
        ctx, result, ctypes.byref(ok), ctypes.byref(value), ctypes.byref(rest)
    ):
        raise RuntimeError("unwrap_result failed")
    # Do not decode an unsuccessful result's payload as if it were the answer.
    # NOTE(review): assumes a nonzero `ok` means success -- confirm in the ABI.
    if not ok.value:
        raise RuntimeError("kernel returned an error result")

    tag_num = ctypes.c_uint64()  # host-value tag; not inspected here
    payload = ctypes.c_uint32()
    if not lib.arb_unwrap_host_value(
        ctx, value.value, ctypes.byref(tag_num), ctypes.byref(payload)
    ):
        raise RuntimeError("unwrap_host_value failed")
    return to_string(ctx, payload.value)


def _apply_string_args(term, args):
    """Apply *term* to each of *args* (as string trees), left to right."""
    current = term
    for a in args:
        arg_tree = lib.arb_of_string(ctx, a.encode("utf-8"))
        current = lib.arb_app(ctx, current, arg_tree)
    return current


def native_run_default(bundle_bytes, args):
    """Run via native bundle loader (fast, ~0.01s).

    Raises:
        RuntimeError: if the bundle fails to load or the result is not a string.
    """
    buf = c_bytes(bundle_bytes)
    term = lib.arb_load_bundle_default(ctx, buf, len(bundle_bytes))
    if term == 0:
        raise RuntimeError("load_bundle_default failed")
    result = lib.arb_reduce(ctx, _apply_string_args(term, args), REDUCE_LIMIT)
    return to_string(ctx, result)


def native_run_named(bundle_bytes, name, args):
    """Run via native bundle loader with named export (fast).

    Raises:
        RuntimeError: if export *name* fails to load or the result is not a
            string.
    """
    buf = c_bytes(bundle_bytes)
    term = lib.arb_load_bundle(ctx, buf, len(bundle_bytes), name.encode("utf-8"))
    if term == 0:
        raise RuntimeError(f"load_bundle({name!r}) failed")
    result = lib.arb_reduce(ctx, _apply_string_args(term, args), REDUCE_LIMIT)
    return to_string(ctx, result)


# ============================================================================
# Tests
# ============================================================================

all_ok = True


def check(label, got, want):
    """Print PASS/FAIL for *label* and record any failure in ``all_ok``."""
    global all_ok
    if got != want:
        print(f"FAIL {label}: got {got!r}, want {want!r}")
        all_ok = False
    else:
        print(f"PASS {label}: {got!r}")


def timed_check(label, want, fn, *fn_args):
    """Call ``fn(*fn_args)``, check the result against *want*, print timing."""
    start = time.perf_counter()  # monotonic, high-resolution interval timer
    got = fn(*fn_args)
    elapsed = time.perf_counter() - start
    check(label, got, want)
    print(f" time: {elapsed * 1000:.1f} ms")


try:
    # Test 1: id via kernel
    print("\n--- Test 1: id (kernel path) ---")
    bundle = read_bundle("id.arboricx")
    timed_check("id kernel", "hello", kernel_run, bundle, ["hello"])

    # Test 2: id via native
    print("\n--- Test 2: id (native path) ---")
    timed_check("id native", "hello", native_run_default, bundle, ["hello"])

    # Test 3: append via kernel
    print("\n--- Test 3: append (kernel path) ---")
    bundle = read_bundle("append.arboricx")
    timed_check(
        "append kernel", "Hello, world!", kernel_run, bundle, ["Hello, ", "world!"]
    )

    # Test 4: append via native
    print("\n--- Test 4: append (native path) ---")
    timed_check(
        "append native",
        "Hello, world!",
        native_run_default,
        bundle,
        ["Hello, ", "world!"],
    )

    # Test 5: append via native named export
    print("\n--- Test 5: append via named export 'append' ---")
    timed_check(
        "append named",
        "Hello, world!",
        native_run_named,
        bundle,
        "append",
        ["Hello, ", "world!"],
    )

    # Test 6: true / false via native
    print("\n--- Test 6: true / false (native path) ---")
    for name, expected in [("true.arboricx", True), ("false.arboricx", False)]:
        bundle = read_bundle(name)
        buf = c_bytes(bundle)
        term = lib.arb_load_bundle_default(ctx, buf, len(bundle))
        if term == 0:
            raise RuntimeError(f"load_bundle_default({name!r}) failed")
        result = lib.arb_reduce(ctx, term, REDUCE_LIMIT)
        check(f"{name} bool", to_bool(ctx, result), expected)

    # Test 7: number roundtrip
    print("\n--- Test 7: number roundtrip ---")
    num_tree = lib.arb_of_number(ctx, 42)
    check("number 42", to_number(ctx, num_tree), 42)

    # Test 8: string roundtrip
    print("\n--- Test 8: string roundtrip ---")
    str_tree = lib.arb_of_string(ctx, b"hello")
    check("string hello", to_string(ctx, str_tree), "hello")
finally:
    # Release the context even when a test raises part-way through.
    lib.arboricx_free(ctx)

if all_ok:
    print("\nAll tests passed!")
    sys.exit(0)
else:
    print("\nSome tests failed!")
    sys.exit(1)