jarp.tree.prelude
PyTree-aware wrappers for callables and proxies, plus supporting PyTree registrations.
This subpackage contains helper wrappers such as Partial and PyTreeProxy, plus the one-time registrations used by filtered JIT for bound methods and Warp arrays.
Classes:
- Partial: Store a partially applied callable as a PyTree-aware proxy.
- PyTreeProxy: Wrap an arbitrary object and flatten the wrapped value as a PyTree.
Functions:
- partial: Partially apply a callable and keep the result compatible with JAX trees.
- register_pytree_prelude: Register the built-in PyTree adapters used by jarp.
Partial
Bases: PartialCallableObjectProxy
Store a partially applied callable as a PyTree-aware proxy.
Bound arguments and keyword arguments flatten as PyTree children, while the wrapped callable itself is partitioned between dynamic data and static metadata when needed.
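To make the flattening rule concrete, here is a minimal sketch of a PyTree-aware partial built directly on jax.tree_util. MiniPartial is a hypothetical name and this is not jarp's implementation: it treats the callable as purely static aux data and skips the dynamic/static partitioning of the callable that Partial performs when needed.

```python
from typing import Any, Callable

import jax
import jax.numpy as jnp


@jax.tree_util.register_pytree_node_class
class MiniPartial:
    """Hypothetical sketch of a PyTree-aware partial; not jarp's Partial."""

    def __init__(self, func: Callable[..., Any], *args: Any, **kwargs: Any) -> None:
        self.func = func
        self.args = args
        self.kwargs = kwargs

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        # Stored positional args go first; call-time kwargs override stored ones.
        return self.func(*self.args, *args, **{**self.kwargs, **kwargs})

    def tree_flatten(self):
        # Bound args and kwargs are children (traced under jit);
        # the callable itself is static aux data in this simplified sketch.
        return (self.args, self.kwargs), self.func

    @classmethod
    def tree_unflatten(cls, func, children):
        args, kwargs = children
        return cls(func, *args, **kwargs)


@jax.jit
def apply(p: MiniPartial, x: jax.Array) -> jax.Array:
    return p(x)


p = MiniPartial(lambda a, x, scale: scale * (a + x), jnp.array(1.0), scale=2.0)
result = apply(p, jnp.array(3.0))  # 2.0 * (1.0 + 3.0)
leaves = jax.tree_util.tree_leaves(p)  # the bound value and the bound keyword
```

Because the callable lives in the aux data, jit retraces only when a different function object is bound, while the bound arrays are traced normally.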
Methods:
- __call__
Attributes:
- __wrapped__ (Callable[..., T])
Source code in src/jarp/tree/prelude/_partial.py
PyTreeProxy
Bases: BaseObjectProxy
Wrap an arbitrary object and flatten the wrapped value as a PyTree.
The proxy itself stays transparent while JAX sees the wrapped object's PyTree structure.
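A minimal sketch of the same idea, assuming a plain `__getattr__` forwarder instead of the object-proxy base class jarp uses (MiniProxy is a hypothetical name, not jarp's PyTreeProxy). The wrapped object is the proxy's only child, so the proxy's leaves are exactly the wrapped object's leaves, and tree operations rebuild a proxy around the transformed value.

```python
from typing import Any

import jax


@jax.tree_util.register_pytree_node_class
class MiniProxy:
    """Hypothetical transparent proxy; a sketch, not jarp's PyTreeProxy."""

    def __init__(self, wrapped: Any) -> None:
        self.__wrapped__ = wrapped

    def __getattr__(self, name: str) -> Any:
        # Only reached when normal lookup fails: forward to the wrapped object.
        # Guard against infinite recursion if __wrapped__ was never set.
        if name == "__wrapped__":
            raise AttributeError(name)
        return getattr(self.__wrapped__, name)

    def tree_flatten(self):
        # Single child: JAX recurses into the wrapped object's own structure.
        return (self.__wrapped__,), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        (wrapped,) = children
        return cls(wrapped)


params = MiniProxy({"w": 2.0, "b": 1.0})
keys = sorted(params.keys())                  # forwarded to dict.keys()
leaves = jax.tree_util.tree_leaves(params)    # dict leaves, sorted by key
doubled = jax.tree_util.tree_map(lambda x: 2 * x, params)
```

Note that special methods such as `__getitem__` are looked up on the type, not the instance, so a `__getattr__`-only forwarder does not make `params["w"]` work; a full object-proxy base class exists to cover cases like that.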
Attributes:
- __wrapped__ (T)
partial
Partially apply a callable and keep the result compatible with JAX trees.
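JAX itself ships a related helper, jax.tree_util.Partial, which illustrates why a tree-compatible partial is useful: the partially applied function can be passed straight into a jitted function as an ordinary argument. The sketch below uses that JAX class as a stand-in; it is not jarp.tree.prelude.partial, and it always treats the wrapped function as static, whereas jarp's Partial can also partition dynamic data out of the callable.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import Partial


def affine(w, x, b):
    return w * x + b


@jax.jit
def run(f, x):
    # f is a PyTree: its bound arguments are traced, the function is static.
    return f(x)


f = Partial(affine, jnp.array(2.0), b=jnp.array(1.0))
out = run(f, jnp.array(3.0))  # affine(2.0, 3.0, b=1.0) = 2.0 * 3.0 + 1.0
```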
register_pytree_prelude (cached)
Register the built-in PyTree adapters used by jarp.
This function is idempotent. It currently registers bound methods and warp.array so they participate correctly in tree traversals.
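A minimal sketch of the one-time registration pattern for bound methods, assuming it relies on functools.cache (the "cached" label suggests this, but the source does not confirm it). register_bound_methods is a hypothetical name, and Warp arrays are left out because this sketch does not assume warp is installed. A bound method flattens to its instance as the only child, with the underlying function kept as static aux data.

```python
import functools
import types

import jax
import jax.numpy as jnp


def _flatten_method(method):
    # Child: the bound instance (flattened further if it is itself a PyTree).
    # Aux data: the plain function, treated as static.
    return (method.__self__,), method.__func__


def _unflatten_method(func, children):
    (instance,) = children
    return types.MethodType(func, instance)


@functools.cache
def register_bound_methods() -> None:
    """Hypothetical: register bound methods as PyTree nodes exactly once."""
    jax.tree_util.register_pytree_node(
        types.MethodType, _flatten_method, _unflatten_method
    )


@jax.tree_util.register_pytree_node_class
class Scale:
    def __init__(self, factor):
        self.factor = factor

    def tree_flatten(self):
        return (self.factor,), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)

    def apply(self, x):
        return self.factor * x


register_bound_methods()
register_bound_methods()  # cache hit: no duplicate-registration error

method = Scale(jnp.float32(2.0)).apply
leaves, treedef = jax.tree_util.tree_flatten(method)
rebuilt = jax.tree_util.tree_unflatten(treedef, leaves)
jitted = jax.jit(lambda m, x: m(x))(method, jnp.float32(3.0))
```

Without the cache, a second call to jax.tree_util.register_pytree_node for the same type raises an error, which is why memoizing the registration function is a simple way to make it idempotent.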