jarp

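jarp keeps mixed JAX PyTrees, which combine arrays with non-array values, usable across function boundaries, attrs classes, and NVIDIA Warp. The package is intentionally small. It focuses on filtered call wrappers, PyTree-friendly class definitions, round-trippable flattening, a handful of control-flow wrappers, and Warp adapters that fit JAX-first workflows. The example below defines a PyTree class with one array field and one static field, then compiles a function over it with filter_jit.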

import jax.numpy as jnp
import jarp


@jarp.define
class Batch:
    values: object = jarp.array()  # array field: a PyTree leaf
    label: str = jarp.static()  # static field: PyTree metadata, not traced


@jarp.filter_jit
def normalize(batch: Batch) -> Batch:
    centered = batch.values - jnp.mean(batch.values)
    return Batch(values=centered, label=batch.label)
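This page names filter_jit without specifying how it works. One common way to get this behaviour, the approach used by Equinox's eqx.filter_jit, can be sketched with plain JAX: flatten the arguments, trace only the array leaves, and feed every other leaf into the compiled function as a fixed Python value. This is a minimal sketch under that assumption, not jarp's implementation; filter_jit_sketch and scale are hypothetical names, and jarp may partition and cache differently.

```python
import jax
import jax.numpy as jnp
import numpy as np


def filter_jit_sketch(fn):
    """Hypothetical filtered jit: trace array leaves, keep every other leaf static."""
    compiled = {}  # one jitted function per (tree structure, static values) combination

    def wrapper(*args):  # positional arguments only, for brevity
        leaves, treedef = jax.tree_util.tree_flatten(args)
        is_array = tuple(isinstance(x, (jax.Array, np.ndarray)) for x in leaves)
        # Non-array leaves (str, int, ...) become part of the cache key, so they must be hashable.
        static = tuple(None if a else x for a, x in zip(is_array, leaves))
        dynamic = [x for a, x in zip(is_array, leaves) if a]
        key = (treedef, is_array, static)
        if key not in compiled:

            def inner(dyn):
                # Re-interleave traced array leaves with the fixed static leaves.
                it = iter(dyn)
                full = [next(it) if a else s for a, s in zip(is_array, static)]
                return fn(*jax.tree_util.tree_unflatten(treedef, full))

            compiled[key] = jax.jit(inner)
        return compiled[key](dynamic)

    return wrapper


@filter_jit_sketch
def scale(x, mode):
    # `mode` arrives as a plain Python str, so an ordinary dict lookup works under jit.
    factor = {"double": 2.0, "half": 0.5}[mode]
    return x * factor


result = scale(jnp.arange(3.0), "double")  # values 0, 2, 4
```

Because the static leaves are part of the cache key, calling scale with a new mode compiles a new function, while calling it again with new array values reuses the existing one.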

Read By Workflow

  • Getting started for installation, the core field specifiers, and the first filtered call wrapper.
  • Call wrappers for filter_jit, fallback_jit, and the jarp.lax helpers.
  • PyTree workflows for define, auto, ravel, PyTreeProxy, and custom registration helpers.
  • Warp interop for to_warp, generic Warp adapters, and dtype helpers.
  • API reference for exact signatures and module-level details.
  • Benchmarks for the current wrapper-overhead and PyTree registration measurements.
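The Batch example above relies on define, array(), and static() to register the class as a PyTree. Assuming jarp.array() marks a field as a PyTree child and jarp.static() marks it as static auxiliary data (an inference from the names, not something this page documents), the plain-JAX equivalent is a class registered by hand with jax.tree_util.register_pytree_node_class. BatchSketch is a hypothetical name; presumably jarp.define generates this boilerplate from attrs field metadata.

```python
import jax
import jax.numpy as jnp


@jax.tree_util.register_pytree_node_class
class BatchSketch:
    """Hypothetical hand-registered stand-in for the Batch class above."""

    def __init__(self, values, label):
        self.values = values  # child: a leaf that JAX transformations see and trace
        self.label = label  # auxiliary data: part of the tree structure, never traced

    def tree_flatten(self):
        # (children, aux_data): only `values` is exposed as a leaf.
        return (self.values,), self.label

    @classmethod
    def tree_unflatten(cls, label, children):
        return cls(children[0], label)


batch = BatchSketch(jnp.arange(3.0), "train")
leaves, treedef = jax.tree_util.tree_flatten(batch)  # leaves holds only the array
shifted = jax.tree_util.tree_map(lambda v: v + 1.0, batch)  # label is carried through
doubled = jax.jit(lambda b: BatchSketch(b.values * 2.0, b.label))(batch)
```

Inside the jitted lambda, b.label is still an ordinary Python str, which is what lets a function like normalize rebuild its output with the same label.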
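For round-trippable flattening, JAX itself ships jax.flatten_util.ravel_pytree, which turns a PyTree of arrays into one flat vector plus a function that rebuilds the original structure. The sketch below uses that JAX function, not jarp.ravel, whose signature is not shown on this page. It assumes every leaf is an array; presumably jarp's version also copes with static fields such as label, but that is an assumption.

```python
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

# A small PyTree of arrays: 4 + 3 = 7 scalar values in total.
params = {"w": jnp.ones((2, 2)), "b": jnp.zeros(3)}

flat, unravel = ravel_pytree(params)  # flat: one 1-D array holding all 7 values
restored = unravel(flat * 2.0)  # same dict keys and leaf shapes as `params`
```

This round trip is what lets routines that expect a single flat vector, such as generic optimizers or linear solvers, work on structured parameters.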