jarp
jarp keeps mixed JAX PyTrees, which combine arrays with non-array data, usable
across function boundaries, attrs classes, and NVIDIA Warp. The package is
intentionally small. It focuses on filtered call wrappers, PyTree-friendly class
definitions, round-trippable flattening, a handful of control-flow wrappers, and
Warp adapters that fit JAX-first workflows.
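```python
import jax.numpy as jnp
import jarp


@jarp.define
class Batch:
    values: object = jarp.array()  # array data, a PyTree leaf
    label: str = jarp.static()  # static metadata, not traced


@jarp.filter_jit
def normalize(batch: Batch) -> Batch:
    # Subtract the mean; the static label passes through unchanged.
    centered = batch.values - jnp.mean(batch.values)
    return Batch(values=centered, label=batch.label)


batch = Batch(values=jnp.array([1.0, 2.0, 3.0]), label="demo")
out = normalize(batch)  # out.values == [-1.0, 0.0, 1.0], out.label == "demo"
```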
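The general technique behind a filtered jit wrapper can be sketched in plain JAX: flatten the arguments, trace only the array leaves, and hand every other leaf (strings, Python scalars, callables) to `jax.jit` as a static, hashable argument. The `my_filter_jit` below is a hypothetical illustration of that technique, not jarp's implementation. Unlike a full wrapper, it assumes the wrapped function returns only arrays and that every non-array leaf is hashable.

```python
import functools

import jax
import jax.numpy as jnp
import numpy as np


def my_filter_jit(fn):
    """Hypothetical sketch: trace array leaves, treat all other leaves as static."""

    def is_array(x):
        return isinstance(x, (jax.Array, np.ndarray))

    # Arguments 1 and 2 (static leaves and tree structure) become part of the
    # jit cache key, so they must be hashable.
    @functools.partial(jax.jit, static_argnums=(1, 2))
    def _call(dynamic_leaves, static_leaves, treedef):
        # Merge traced arrays back with the static values, position by position.
        leaves = [s if d is None else d for d, s in zip(dynamic_leaves, static_leaves)]
        args, kwargs = jax.tree_util.tree_unflatten(treedef, leaves)
        return fn(*args, **kwargs)

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        leaves, treedef = jax.tree_util.tree_flatten((args, kwargs))
        # None marks "not in this partition"; JAX treats None as an empty subtree.
        dynamic = [x if is_array(x) else None for x in leaves]
        static = tuple(None if is_array(x) else x for x in leaves)
        return _call(dynamic, static, treedef)

    return wrapper


@my_filter_jit
def scale(x, mode):
    # `mode` is a plain string, so ordinary Python branching on it is allowed.
    return x * 2.0 if mode == "double" else x


print(scale(jnp.arange(3.0), "double"))
```

Because static leaves join the cache key, each distinct non-array value (for example a new Python float) triggers a retrace in this sketch.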
Read By Workflow
- Getting started for installation, the core field specifiers, and the first filtered call wrapper.
- Call wrappers for `filter_jit`, `fallback_jit`, and the `jarp.lax` helpers.
- PyTree workflows for `define`, `auto`, `ravel`, `PyTreeProxy`, and custom registration helpers.
- Warp interop for `to_warp`, generic Warp adapters, and dtype helpers.
- API reference for exact signatures and module-level details.
- Benchmarks for the current wrapper-overhead and PyTree registration measurements.
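For a point of comparison with round-trippable flattening, JAX itself ships `jax.flatten_util.ravel_pytree`, which flattens a PyTree of arrays into one 1-D vector and returns a function that rebuilds the original structure. The sketch below uses only that JAX utility; it does not show the signature of jarp's own `ravel`, and the `params` dictionary is a made-up example.

```python
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

# A made-up PyTree of arrays; nested dicts, lists, and tuples also work.
params = {"w": jnp.ones((2, 2)), "b": jnp.zeros(3)}

# Flatten every leaf into a single 1-D vector (4 + 3 = 7 entries here).
flat, unravel = ravel_pytree(params)

# `unravel` inverts the flattening, restoring the original keys and shapes.
restored = unravel(flat * 2.0)
print(flat.shape, restored["w"].shape)
```

A flat vector like this is convenient for optimizers or solvers that expect a single array, while `unravel` maps results back onto the structured PyTree.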