jarp.tree.prelude

PyTree-aware wrappers for callables, proxies, and support registrations.

This subpackage contains helper wrappers such as Partial and PyTreeProxy, plus the one-time registrations used by filtered JIT for bound methods and Warp arrays.

Classes:

  • Partial

    Store a partially applied callable as a PyTree-aware proxy.

  • PyTreeProxy

    Wrap an arbitrary object and flatten the wrapped value as a PyTree.

Functions:

  • partial

    Partially apply a callable and keep the result compatible with JAX trees.

  • register_pytree_prelude

    Register the built-in PyTree adapters used by jarp.

Partial

Partial(
    func: Callable[..., T], /, *args: Any, **kwargs: Any
)

Bases: PartialCallableObjectProxy


Store a partially applied callable as a PyTree-aware proxy.

Bound positional and keyword arguments flatten as PyTree children, while the wrapped callable itself is partitioned into dynamic data and static metadata when needed.
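The flattening rule can be illustrated with a minimal pure-Python sketch. `PartialSketch` and `affine` are hypothetical names and are not part of jarp. The class also simplifies things: it treats the whole callable as static metadata rather than partitioning it. It implements the `tree_flatten`/`tree_unflatten` method pair that `jax.tree_util.register_pytree_node_class` expects, so with JAX installed it could be decorated with that function. The sketch itself does not import JAX.

```python
from typing import Any, Callable


class PartialSketch:
    """Hypothetical stand-in for jarp's Partial, not its actual implementation."""

    def __init__(self, func: Callable[..., Any], /, *args: Any, **kwargs: Any) -> None:
        self.func = func
        self.args = args
        self.kwargs = kwargs

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        # Bound positional arguments come first; call-time keywords override bound ones.
        return self.func(*self.args, *args, **{**self.kwargs, **kwargs})

    def tree_flatten(self) -> tuple[tuple[Any, ...], Any]:
        # Children: the bound arguments, which JAX would flatten further.
        # Aux data: the callable, kept whole as static metadata (a simplification).
        return (self.args, self.kwargs), self.func

    @classmethod
    def tree_unflatten(cls, aux: Any, children: tuple[Any, ...]) -> "PartialSketch":
        args, kwargs = children
        return cls(aux, *args, **kwargs)


def affine(a: float, x: float, *, b: float = 0.0) -> float:
    return a * x + b


p = PartialSketch(affine, 2.0, b=1.0)
children, aux = p.tree_flatten()  # ((2.0,), {"b": 1.0}) and affine
rebuilt = PartialSketch.tree_unflatten(aux, children)
print(rebuilt(3.0))  # 7.0
```

Because the callable is stored as aux data, two sketches that wrap the same function with the same keyword names share one tree structure. Only the bound values differ, and those appear as leaves.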

Source code in src/jarp/tree/prelude/_partial.py
def __init__(self, func: Callable[..., T], /, *args: Any, **kwargs: Any) -> None:
    """Create a proxy that records bound arguments for PyTree flattening."""
    super().__init__(func, *args, **kwargs)
    self._self_args = args
    self._self_kwargs = kwargs

__wrapped__ instance-attribute

__wrapped__: Callable[..., T]

__call__

__call__(*args: P.args, **kwargs: P.kwargs) -> T
Source code in src/jarp/tree/prelude/_partial.py
def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T: ...

PyTreeProxy

Bases: BaseObjectProxy


Wrap an arbitrary object and flatten the wrapped value as a PyTree.

The proxy itself stays transparent while JAX sees the wrapped object's PyTree structure.
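A transparent proxy of this kind can be sketched in a few lines. `ProxySketch` is hypothetical and much simpler than the `wrapt`-based `PyTreeProxy`. It forwards unknown attribute lookups to the wrapped object and flattens to a single child. If the class were registered, for example with `jax.tree_util.register_pytree_node_class`, JAX would then flatten the wrapped value itself.

```python
from typing import Any


class ProxySketch:
    """Hypothetical minimal proxy; not jarp's PyTreeProxy."""

    def __init__(self, wrapped: Any) -> None:
        self.__wrapped__ = wrapped  # dunder names are not name-mangled

    def __getattr__(self, name: str) -> Any:
        # Called only when normal lookup fails, so other attribute
        # reads fall through to the wrapped object.
        if name == "__wrapped__":
            raise AttributeError(name)  # avoid recursion before __init__ runs
        return getattr(self.__wrapped__, name)

    def tree_flatten(self) -> tuple[tuple[Any], None]:
        # One child: the wrapped value, whose own PyTree structure JAX would expand.
        return (self.__wrapped__,), None

    @classmethod
    def tree_unflatten(cls, aux: None, children: tuple[Any]) -> "ProxySketch":
        (wrapped,) = children
        return cls(wrapped)


params = {"w": 1.0, "b": 2.0}
proxy = ProxySketch(params)
print(sorted(proxy.keys()))  # ['b', 'w'], forwarded to the dict
```

This sketch has one limitation that shows why a full proxy library is worth using. Implicit special-method lookups such as `proxy["w"]` bypass `__getattr__`, so the sketch does not support indexing. General-purpose proxy libraries such as `wrapt` define many of these special methods explicitly.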

Attributes:

__wrapped__ instance-attribute

__wrapped__: T

partial

partial[T](
    func: Callable[..., T], /, *args: Any, **kwargs: Any
) -> Partial[..., T]

Partially apply a callable and keep the result compatible with JAX trees.

Source code in src/jarp/tree/prelude/_partial.py
def partial[T](func: Callable[..., T], /, *args: Any, **kwargs: Any) -> Partial[..., T]:
    """Partially apply a callable and keep the result compatible with JAX trees."""
    return Partial(func, *args, **kwargs)

register_pytree_prelude cached

register_pytree_prelude() -> None

Register the built-in PyTree adapters used by jarp.

This function is idempotent. It currently registers bound methods and warp.array so they participate correctly in tree traversals.
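The source below does not show how bound methods are flattened. The following sketch shows one plausible scheme, under that assumption: the bound instance becomes a child and the underlying function becomes static metadata. The helper names `flatten_method` and `unflatten_method` are hypothetical and may not match what `register_pytree_method` actually does. Functions with these shapes are what `jax.tree_util.register_pytree_node(types.MethodType, flatten, unflatten)` accepts.

```python
import types
from typing import Any, Callable


def flatten_method(method: types.MethodType) -> tuple[tuple[Any], Callable[..., Any]]:
    # Child: the bound instance, which JAX could traverse further.
    # Aux data: the plain function, hashable static metadata.
    return (method.__self__,), method.__func__


def unflatten_method(func: Callable[..., Any], children: tuple[Any]) -> types.MethodType:
    (instance,) = children
    # Rebind the same function to the (possibly transformed) instance.
    return types.MethodType(func, instance)


class Scaler:
    def __init__(self, factor: float) -> None:
        self.factor = factor

    def apply(self, x: float) -> float:
        return self.factor * x


children, aux = flatten_method(Scaler(3.0).apply)
print(unflatten_method(aux, children)(2.0))  # 6.0
print(unflatten_method(aux, (Scaler(10.0),))(2.0))  # 20.0
```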

Source code in src/jarp/tree/prelude/_prelude.py
@functools.cache  # run only once
def register_pytree_prelude() -> None:
    """Register the built-in PyTree adapters used by jarp.

    This function is idempotent. It currently registers bound methods and
    ``warp.array`` so they participate correctly in tree traversals.
    """
    register_pytree_method()
    register_warp_array()
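The `functools.cache` decorator makes repeated calls safe. JAX raises an error when the same type is registered twice, so the registration body must run only once per process. The sketch below reproduces this pattern with a hypothetical in-memory registry. `REGISTRY`, `register_adapter`, and `register_prelude_sketch` are illustrative names, not jarp or JAX APIs.

```python
import functools

# Hypothetical stand-in for a global registry that rejects duplicates.
REGISTRY: list[str] = []


def register_adapter(name: str) -> None:
    if name in REGISTRY:
        raise ValueError(f"duplicate registration: {name}")
    REGISTRY.append(name)


@functools.cache  # the body runs once; later calls return the cached None
def register_prelude_sketch() -> None:
    register_adapter("bound method")
    register_adapter("warp.array")


register_prelude_sketch()
register_prelude_sketch()  # no ValueError: the cached result is reused
print(REGISTRY)  # ['bound method', 'warp.array']
```

In this sketch, `functools.cache` stores nothing if the body raises partway through. A retry would therefore re-run registrations that had already succeeded.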