PyTree Workflows
This guide covers the pieces of liblaf.jarp that make mixed data-and-metadata
trees behave predictably under JAX.
Choose Field Behavior Explicitly
import jax.numpy as jnp
import jax.tree_util as jtu
from liblaf import jarp
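@jarp.define
class Example:
    data: object = jarp.array(default=0.0)  # always a dynamic child
    label: str = jarp.static(default="")  # always static metadata
    extra: object = jarp.auto(default="")  # decided by the runtime value

obj = Example()
leaves, _ = jtu.tree_flatten(obj)  # extra is a string here, so it stays static
obj.extra = jnp.zeros(())
leaves_with_extra, _ = jtu.tree_flatten(obj)  # extra is now an array leaf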
data always flattens as a JAX child. label always stays static. extra
follows the runtime value: a string stays static, while an array becomes a
dynamic child. The runtime check is the same one exposed by is_data.
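The runtime split described above can be pictured with plain JAX and Python. This is only a sketch, not jarp's implementation: `looks_like_data` and `split_fields` are hypothetical names, and the sketch assumes that "data" simply means a JAX or NumPy array, which may be narrower or broader than what is_data actually accepts.

```python
import jax
import jax.numpy as jnp
import numpy as np


def looks_like_data(value) -> bool:
    """Hypothetical stand-in for is_data: treat arrays as dynamic data."""
    return isinstance(value, (jax.Array, np.ndarray))


def split_fields(fields: dict):
    """Split a mapping into dynamic (array) and static (everything else) parts."""
    dynamic = {k: v for k, v in fields.items() if looks_like_data(v)}
    static = {k: v for k, v in fields.items() if not looks_like_data(v)}
    return dynamic, static


# With a string in "extra", only "data" is dynamic.
dynamic, static = split_fields({"data": jnp.zeros(()), "label": "", "extra": ""})

# Once "extra" holds an array, it moves to the dynamic side.
dynamic_after, static_after = split_fields(
    {"data": jnp.zeros(()), "label": "", "extra": jnp.zeros(())}
)
```

The same value-based decision is why the number of leaves changes between the two tree_flatten calls in the example above.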
Flatten A Tree Once And Reuse Its Structure
import jax.numpy as jnp
from liblaf import jarp
payload = {"a": jnp.zeros((3,)), "b": jnp.ones((4,)), "static": "foo"}
flat, structure = jarp.ravel(payload)  # one vector plus the reusable layout
same_shape = {"a": jnp.ones((3,)), "b": jnp.zeros((4,)), "static": "foo"}
flat_again = structure.ravel(same_shape)  # flatten another compatible tree
round_trip = structure.unravel(flat)  # rebuild the original layout
Use ravel when an optimizer, solver, or serialization step wants one vector
without losing the tree layout or static leaves. The returned Structure can
flatten another compatible tree later or rebuild the original layout from a
flat vector.
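For comparison, plain JAX offers a similar round trip through `jax.flatten_util.ravel_pytree`. The sketch below uses only that function and is not how jarp.ravel is implemented. It leaves out the "static" entry because `ravel_pytree` expects every leaf to be an array; keeping static leaves alongside the vector is the part the text above attributes to ravel.

```python
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

# Only array leaves here: ravel_pytree has no notion of static leaves.
arrays = {"a": jnp.zeros((3,)), "b": jnp.ones((4,))}

# One flat vector of length 3 + 4, plus a function that restores the layout.
flat, unravel = ravel_pytree(arrays)

# Any vector of the same length can be mapped back onto the dict layout.
restored = unravel(flat * 2.0)
```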
Carry Enum State As Dynamic Data
import enum
import jax
import jax.numpy as jnp
from liblaf import jarp
class Phase(jarp.Enum):
    START = enum.auto()
    RUNNING = enum.auto()
    DONE = enum.auto()

@jax.jit
def choose(mask):
    return Phase.where(mask, Phase.START, Phase.RUNNING)

phase = choose(jnp.array([True, False, True]))
Enum stores its integer value as a dynamic JAX leaf, so enum state can be
carried through jax.jit, jax.lax.while_loop, and other PyTree-aware APIs.
Scalar values still resolve to ordinary enum members. Vectorized choices can
produce an enum object named "<unknown>" because the array may hold several
member values at once.
Use Enum.where for a two-way choice and Enum.select for ordered condition
lists. The lower-level tree.where and tree.select apply the same leafwise
selection to any matching PyTree structure.
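The idea of carrying enum state as an integer leaf can be sketched with plain JAX, without jarp. `PhaseCode`, `choose_code`, and `classify` are hypothetical names used only here, and the sketch assumes the enum values are small integers. It uses `jnp.where` for the two-way choice and `jnp.select` for an ordered condition list.

```python
import enum

import jax
import jax.numpy as jnp


class PhaseCode(enum.IntEnum):
    """Hypothetical plain-Python enum used only for this sketch."""

    START = 1
    RUNNING = 2
    DONE = 3


@jax.jit
def choose_code(mask):
    # Two-way choice between integer codes, traced like any other array op.
    return jnp.where(mask, int(PhaseCode.START), int(PhaseCode.RUNNING))


@jax.jit
def classify(x):
    # Ordered conditions: the first true condition wins, otherwise the default.
    return jnp.select(
        [x < 0.0, x < 1.0],
        [
            jnp.full(x.shape, int(PhaseCode.START)),
            jnp.full(x.shape, int(PhaseCode.RUNNING)),
        ],
        default=int(PhaseCode.DONE),
    )


codes = choose_code(jnp.array([True, False, True]))
ordered = classify(jnp.array([-1.0, 0.5, 2.0]))

# A single scalar code maps back to exactly one member.
first = PhaseCode(int(codes[0]))
```

A whole array of codes has no single member to map back to, which matches the "<unknown>" case described above.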
Wrap Foreign Objects As PyTrees
import jax
import jax.numpy as jnp
from liblaf import jarp
proxy = jarp.PyTreeProxy((jnp.zeros(()), "static"))
leaves, treedef = jax.tree.flatten(proxy)
restored = jax.tree.unflatten(treedef, leaves)
PyTreeProxy keeps the wrapper transparent while JAX traverses the wrapped
value. partial provides the same idea for partially applied callables whose
bound arguments should remain visible to tree traversals.
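JAX itself ships a comparable wrapper, `jax.tree_util.Partial`, which makes a quick illustration of why visible bound arguments matter. The sketch below uses only that class; `scale` and `apply` are hypothetical helpers, and nothing here is meant to describe how jarp's partial is built.

```python
import jax
import jax.numpy as jnp
import jax.tree_util as jtu


def scale(factor, x):
    """Hypothetical function whose first argument gets bound."""
    return factor * x


# The bound array is a PyTree child of the Partial object.
bound = jtu.Partial(scale, jnp.asarray(2.0))
leaves = jtu.tree_leaves(bound)


@jax.jit
def apply(fn, x):
    # fn crosses the jit boundary as a PyTree, so its bound array is traced.
    return fn(x)


result = apply(bound, jnp.ones((3,)))
```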
Importing jarp.tree also imports the private prelude module that registers
bound methods and warp.array with JAX. That means methods and Warp arrays are
ready before the public partitioning, raveling, and wrapper helpers need them.
Key-aware traversals report useful paths for these adapters: bound methods use
__self__, and partial exposes _self_args, _self_kwargs, and
__wrapped__.
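Key-aware traversal is easiest to see on built-in containers first. The sketch below uses only `jax.tree_util.tree_flatten_with_path` and `jax.tree_util.keystr` on a plain dict; the paths that the jarp adapters report for bound methods and partial objects are not reproduced here.

```python
import jax.numpy as jnp
import jax.tree_util as jtu

tree = {"a": jnp.zeros(()), "b": [jnp.ones(())]}

# Each leaf comes back paired with the key path that leads to it.
leaves_with_paths, _ = jtu.tree_flatten_with_path(tree)

# keystr renders a key path as a readable string.
paths = [jtu.keystr(path) for path, _ in leaves_with_paths]
```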
Register Classes Without jarp.define
Use register_fieldz when an attrs class already carries the right field
metadata. Use register_generic when a class does not come from attrs or
when you want to spell out which fields are always data, always metadata, or
filtered at runtime.
register_generic builds specialized flatten and unflatten callbacks, and it
can bypass custom __setattr__ implementations when needed during unflatten.
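To make the callback idea concrete, here is a hand-written pair of flatten and unflatten callbacks registered with plain `jax.tree_util.register_pytree_node`. It is a sketch of the general technique, not the code that register_generic generates: `Locked`, `flatten_locked`, and `unflatten_locked` are hypothetical, and the sketch assumes that bypassing a custom `__setattr__` means calling `object.__setattr__` directly.

```python
import jax.numpy as jnp
import jax.tree_util as jtu


class Locked:
    """Hypothetical class whose __setattr__ rejects ordinary assignment."""

    def __init__(self, value, name):
        object.__setattr__(self, "value", value)
        object.__setattr__(self, "name", name)

    def __setattr__(self, key, val):
        raise AttributeError("Locked is read-only")


def flatten_locked(obj):
    # Children are dynamic data; the second item is static aux data.
    return (obj.value,), obj.name


def unflatten_locked(name, children):
    obj = object.__new__(Locked)  # skip __init__ entirely
    object.__setattr__(obj, "value", children[0])  # bypass custom __setattr__
    object.__setattr__(obj, "name", name)
    return obj


jtu.register_pytree_node(Locked, flatten_locked, unflatten_locked)

locked = Locked(jnp.zeros((2,)), "demo")
shifted = jtu.tree_map(lambda x: x + 1.0, locked)
```

Writing attributes through `object.__setattr__` lets unflatten rebuild the instance even though normal assignment raises.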
See the API reference for jarp.tree, jarp.tree.prelude, and jarp.tree.codegen
for the exact registration API and generated callback helpers.