PyTree Workflows

This guide covers the pieces of jarp that make mixed data-and-metadata trees behave predictably under JAX.

Choose Field Behavior Explicitly

import jax.numpy as jnp
import jax.tree_util as jtu
import jarp


@jarp.define
class Example:
    data: object = jarp.array(default=0.0)
    label: str = jarp.static(default="")
    extra: object = jarp.auto(default="")


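obj = Example()
# extra defaults to a string, so it is treated as static in this flatten.
leaves, _ = jtu.tree_flatten(obj)

# Once extra holds an array, it is flattened as a dynamic child as well.
obj.extra = jnp.zeros(())
leaves_with_extra, _ = jtu.tree_flatten(obj)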

data always flattens as a JAX child. label always stays static. extra follows the runtime value: a string stays static, while an array becomes a dynamic child. The runtime check is the same one exposed by is_data.
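For comparison, the always-data versus always-static split can be sketched with plain JAX alone. This is not jarp's implementation: it uses jax.tree_util.register_dataclass, and the class PlainExample and its field values are hypothetical names chosen for illustration.

```python
from dataclasses import dataclass
from functools import partial

import jax
import jax.numpy as jnp
import jax.tree_util as jtu


# Plain-JAX analogue (not jarp): `data` is a child, `label` is static metadata.
@partial(jtu.register_dataclass, data_fields=["data"], meta_fields=["label"])
@dataclass
class PlainExample:
    data: jax.Array
    label: str


leaves_a, treedef_a = jtu.tree_flatten(PlainExample(jnp.zeros(()), "a"))
leaves_b, treedef_b = jtu.tree_flatten(PlainExample(jnp.ones(()), "b"))

# Only `data` shows up as a leaf; `label` is stored in the tree definition.
n_leaves = len(leaves_a)
# Static metadata is part of the structure, so different labels give
# different tree definitions even though the leaves have the same shape.
same_structure = treedef_a == treedef_b
```

Because static fields live in the tree definition, changing one changes the structure JAX sees, which is why a field that can hold either kind of value has to be classified at flatten time.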

Flatten A Tree Once And Reuse Its Structure

import jax.numpy as jnp
import jarp


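payload = {"a": jnp.zeros((3,)), "b": jnp.ones((4,)), "static": "foo"}
# Flatten once; structure records the tree layout and the static leaf.
flat, structure = jarp.ravel(payload)

same_shape = {"a": jnp.ones((3,)), "b": jnp.zeros((4,)), "static": "foo"}
# Reuse the recorded structure on a compatible tree, or invert the flattening.
flat_again = structure.ravel(same_shape)
round_trip = structure.unravel(flat)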

Use ravel when an optimizer, solver, or serialization step wants one vector without losing the tree layout or static leaves. The returned Structure can flatten another compatible tree later or rebuild the original layout from a flat vector.
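As a point of reference, JAX ships a related helper, jax.flatten_util.ravel_pytree, which returns a flat vector and an unravel callback. The sketch below uses only that helper, not jarp, and the params dictionary is a hypothetical example. Note that ravel_pytree expects every leaf to be array-like, so the string leaf from the example above is left out here.

```python
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

# Hypothetical parameter tree containing array leaves only.
params = {"a": jnp.zeros((3,)), "b": jnp.ones((4,))}

# One flat vector plus a callback that restores the dictionary layout.
flat, unravel = ravel_pytree(params)

# Pretend an optimizer updated the flat vector, then rebuild the tree.
restored = unravel(flat + 1.0)
```

The stock helper hands back a one-way callback; a reusable structure object, as described above, is what lets a second compatible tree be flattened with the same layout.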

Wrap Foreign Objects As PyTrees

import jax
import jax.numpy as jnp
import jarp


# Wrap a plain tuple; JAX traverses through the proxy into the wrapped value.
proxy = jarp.PyTreeProxy((jnp.zeros(()), "static"))
leaves, treedef = jax.tree.flatten(proxy)
restored = jax.tree.unflatten(treedef, leaves)

PyTreeProxy keeps the wrapper transparent while JAX traverses the wrapped value. partial provides the same idea for partially applied callables whose bound arguments should remain visible to tree traversals.
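The idea of a partially applied callable whose bound arguments stay visible to tree traversals also exists in JAX itself as jax.tree_util.Partial. The sketch below uses that class rather than jarp's partial, whose exact signature is not shown in this guide; the function scale is a hypothetical example.

```python
import jax.numpy as jnp
import jax.tree_util as jtu
from jax.tree_util import Partial


def scale(factor, x):
    """Hypothetical function with one argument bound ahead of time."""
    return factor * x


bound = Partial(scale, jnp.asarray(2.0))

# The bound array is an ordinary leaf, so tree utilities can see it.
leaves = jtu.tree_leaves(bound)

# Mapping over the partial rewrites the bound argument, not the function.
doubled = jtu.tree_map(lambda leaf: leaf * 2, bound)
result = doubled(jnp.asarray(3.0))
```

Here the function itself is carried as static structure while the bound value is a dynamic leaf, so the callable can be passed through transformations like any other tree.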

register_pytree_prelude performs the built-in one-time registrations used by the higher-level wrappers, including those for bound methods and warp.array. Most users need to call it directly only when they want those registrations to happen early.

Register Classes Without jarp.define

Use register_fieldz when an attrs class already carries the right field metadata. Use register_generic when a class does not come from attrs or when you want to spell out which fields are always data, always metadata, or filtered at runtime.

register_generic builds specialized flatten and unflatten callbacks, and it can bypass custom __setattr__ implementations when needed during unflatten.
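To illustrate why an unflatten callback may need to bypass a custom __setattr__, here is a hand-written sketch in plain JAX. It does not show the callbacks register_generic actually generates; the class Frozen and the helpers _flatten and _unflatten are hypothetical, and the bypass shown (object.__new__ plus object.__setattr__) is one common way to do it.

```python
import jax.numpy as jnp
import jax.tree_util as jtu


class Frozen:
    """Hypothetical class whose __setattr__ rejects normal assignment."""

    def __init__(self, value, name):
        object.__setattr__(self, "value", value)
        object.__setattr__(self, "name", name)

    def __setattr__(self, key, val):
        raise AttributeError("Frozen instances are read-only")


def _flatten(obj):
    # `value` is the dynamic child; `name` travels as static aux data.
    return (obj.value,), obj.name


def _unflatten(name, children):
    # Skip __init__ and the custom __setattr__ by going through object directly.
    obj = object.__new__(Frozen)
    object.__setattr__(obj, "value", children[0])
    object.__setattr__(obj, "name", name)
    return obj


jtu.register_pytree_node(Frozen, _flatten, _unflatten)

original = Frozen(jnp.ones((2,)), "weights")
shifted = jtu.tree_map(lambda x: x + 1.0, original)
```

Rebuilding through object.__setattr__ keeps the read-only guarantee for user code while still letting tree transformations produce new instances.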

See the API reference for jarp.tree, jarp.tree.prelude, and jarp.tree.codegen for the exact registration API and generated callback helpers.