Skip to content

liblaf.apple.energy ¤

Modules:

Classes:

  • Arap

    As-Rigid-As-Possible.

  • ArapActive

    As-Rigid-As-Possible with activation (driven by the `activation` and `muscle_fraction` properties).

  • CollisionCandidatesVertFace

    CollisionCandidatesVertFace(closest: jaxtyping.Float[Array, 'points 3'] = None, collide: jaxtyping.Bool[Array, 'points'] = None, distance: jaxtyping.Float[Array, 'points'] = None, face_id: jaxtyping.Integer[Array, 'points'] = None, face_normal: jaxtyping.Float[Array, 'points 3'] = None, uv: jaxtyping.Float[Array, 'points 2'] = None)

  • CollisionVertFace

    CollisionVertFace(rigid: liblaf.apple.sim.actor.actor.Actor = None, soft: liblaf.apple.sim.actor.actor.Actor = None, stiffness: jaxtyping.Float[Array, ''] = 100000.0, rest_length: jaxtyping.Float[Array, ''] = 0.001, max_dist: jaxtyping.Float[Array, ''] = 0.01, epsilon: jaxtyping.Float[Array, ''] = 0.001, candidates: liblaf.apple.energy.collision.vert_face.vert_face.CollisionCandidatesVertFace = <dynamic>, *, id: str = None, filter_hess_diag: bool = True, filter_hess_quad: bool = True)

  • EnergyZero

    EnergyZero(_actors: liblaf.apple.struct.dictutils._node_container.NodeContainer[liblaf.apple.sim.actor.actor.Actor] = <dynamic>, *, id: str = None)

  • PhaceActive

    PhaceActive(actor: liblaf.apple.sim.actor.actor.Actor, *, id: str = None, hess_diag_filter: bool = True, hess_quad_filter: bool = True)

  • PhacePassive

    PhacePassive(actor: liblaf.apple.sim.actor.actor.Actor, *, id: str = None, hess_diag_filter: bool = True, hess_quad_filter: bool = True)

Arap ¤

Bases: Elastic

As-Rigid-As-Possible.

\[ \Psi = \frac{\mu}{2} \|F - R\|_F^2 = \frac{\mu}{2} (I_2 - 2 I_1 + 3) \]

Parameters:

  • id (str, default: None ) –
  • actor (Actor) –
  • hess_diag_filter (bool, default: True ) –
  • hess_quad_filter (bool, default: True ) –

Methods:

Attributes:

__dataclass_fields__ class-attribute ¤

__dataclass_fields__: dict[str, Field[Any]]

actor class-attribute instance-attribute ¤

actor: Actor = field()

actors property ¤

hess_diag_filter class-attribute instance-attribute ¤

hess_diag_filter: bool = field(default=True, kw_only=True)

hess_quad_filter class-attribute instance-attribute ¤

hess_quad_filter: bool = field(default=True, kw_only=True)

id class-attribute instance-attribute ¤

id: str = field(default=None, kw_only=True)

mu property ¤

mu: Float[Array, ' c']

region property ¤

region: Region

__post_init__ ¤

__post_init__() -> None
Source code in src/liblaf/apple/struct/tree/_node.py
13
14
15
def __post_init__(self) -> None:
    """Fill in ``id`` with ``uniq_id(self)`` when it was left as ``None``."""
    if self.id is not None:
        return
    # ``object.__setattr__`` bypasses the dataclass ``__setattr__``, so the
    # assignment also works if the dataclass is frozen.
    object.__setattr__(self, "id", uniq_id(self))

energy_density ¤

energy_density(
    field: Field, /, params: GlobalParams
) -> Float[Array, "c q"]
Source code in src/liblaf/apple/energy/elastic/arap/_arap.py
24
25
26
27
28
29
30
31
32
33
34
35
@override
@utils.jit(inline=True)
def energy_density(
    self, field: Field, /, params: sim.GlobalParams
) -> Float[jax.Array, "c q"]:
    """ARAP energy density at every quadrature point of every cell.

    Args:
        field: The actor's field; the deformation gradient is computed from it.
        params: Global simulation parameters (not used by this energy).

    Returns:
        Energy density with shape ``(c, q)``.
    """
    region: sim.Region = self.region
    # The kernel consumes a flattened ``cq`` batch axis.
    F_cq: Float[jax.Array, "cq J J"] = region.squeeze_cq(
        region.deformation_gradient(field)
    )
    psi_cq: Float[jax.Array, " cq"]
    (psi_cq,) = kernel.arap_energy_density_kernel(F_cq, self.mu)
    # Restore the (c, q) layout.
    return region.unsqueeze_cq(psi_cq)

energy_density_hess_diag ¤

energy_density_hess_diag(
    field: Field, /, params: GlobalParams
) -> Float[Array, "c q a J"]
Source code in src/liblaf/apple/energy/elastic/arap/_arap.py
50
51
52
53
54
55
56
57
58
59
60
61
62
63
@override
@utils.jit(inline=True)
def energy_density_hess_diag(
    self, field: Field, /, params: sim.GlobalParams
) -> Float[jax.Array, "c q a J"]:
    """Per-quadrature-point diagonal of the ARAP energy-density Hessian.

    Args:
        field: The actor's field; the deformation gradient is computed from it.
        params: Global simulation parameters (not used by this energy).

    Returns:
        Hessian diagonal with shape ``(c, q, a, J)``. ``a`` is the per-cell
        point axis of ``region.dhdX``.
    """
    region: sim.Region = self.region
    # The stale ``hess_diag: Float[jax.Array, "cells 4 3"]`` pre-declaration was
    # removed: it contradicted the real ``cq a J`` / ``c q a J`` shapes below.
    F: Float[jax.Array, "c q J J"] = region.deformation_gradient(field)
    F: Float[jax.Array, "cq J J"] = region.squeeze_cq(F)
    dhdX: Float[jax.Array, "cq a J"] = region.squeeze_cq(region.dhdX)
    hess_diag: Float[jax.Array, "cq a J"]
    (hess_diag,) = kernel.arap_energy_density_hess_diag_kernel(F, self.mu, dhdX)
    hess_diag: Float[jax.Array, "c q a J"] = region.unsqueeze_cq(hess_diag)
    return hess_diag

energy_density_hess_quad ¤

energy_density_hess_quad(
    field: Field, p: Field, /, params: GlobalParams
) -> Float[Array, "c q"]
Source code in src/liblaf/apple/energy/elastic/arap/_arap.py
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
@override
@utils.jit(inline=True)
def energy_density_hess_quad(
    self, field: Field, p: Field, /, params: sim.GlobalParams
) -> Float[jax.Array, "c q"]:
    """Per-quadrature-point Hessian quadratic form of the ARAP density along ``p``.

    Args:
        field: The actor's field at which the Hessian is evaluated.
        p: Direction field; it is scattered to cells before the kernel call.
        params: Global simulation parameters (not used by this energy).

    Returns:
        Quadratic-form values with shape ``(c, q)``.
    """
    region: sim.Region = self.region
    F_cq: Float[jax.Array, "cq J J"] = region.squeeze_cq(
        region.deformation_gradient(field)
    )
    dhdX_cq: Float[jax.Array, "cq a J"] = region.squeeze_cq(region.dhdX)
    p_cells = region.scatter(p)
    quad_cq: Float[jax.Array, " cq"]
    (quad_cq,) = kernel.arap_energy_density_hess_quad_kernel(
        F_cq, p_cells, self.mu, dhdX_cq
    )
    return region.unsqueeze_cq(quad_cq)

energy_density_jac ¤

energy_density_jac(
    field: Field, /, params: GlobalParams
) -> Float[Array, "c q a J"]
Source code in src/liblaf/apple/energy/elastic/_elastic.py
116
117
118
119
120
121
122
@utils.jit(inline=True)
def energy_density_jac(
    self, field: Field, /, params: sim.GlobalParams
) -> Float[Array, "c q a J"]:
    """Energy-density gradient obtained by pulling the PK1 stress back through the region.

    Args:
        field: The actor's field.
        params: Global simulation parameters, forwarded to the stress.

    Returns:
        Gradient with shape ``(c, q, a, J)``.
    """
    stress: Float[Array, "c q J J"] = self.first_piola_kirchhoff_stress(field, params)
    return self.region.gradient_vjp(stress)

first_piola_kirchhoff_stress ¤

first_piola_kirchhoff_stress(
    field: Field, /, params: GlobalParams
) -> Float[Array, "c q J J"]
Source code in src/liblaf/apple/energy/elastic/arap/_arap.py
37
38
39
40
41
42
43
44
45
46
47
48
@override
@utils.jit(inline=True)
def first_piola_kirchhoff_stress(
    self, field: Field, /, params: sim.GlobalParams
) -> Float[jax.Array, "c q J J"]:
    """First Piola-Kirchhoff stress of the ARAP model at every quadrature point.

    Args:
        field: The actor's field; the deformation gradient is computed from it.
        params: Global simulation parameters (not used by this energy).

    Returns:
        Stress tensors with shape ``(c, q, J, J)``.
    """
    region: sim.Region = self.region
    F_cq: Float[jax.Array, "cq J J"] = region.squeeze_cq(
        region.deformation_gradient(field)
    )
    stress_cq: Float[jax.Array, "cq J J"]
    (stress_cq,) = kernel.arap_first_piola_kirchhoff_stress_kernel(F_cq, self.mu)
    return region.unsqueeze_cq(stress_cq)

from_actor classmethod ¤

from_actor(
    actor: Actor,
    *,
    hess_diag_filter: bool = True,
    hess_quad_filter: bool = True,
) -> Self
Source code in src/liblaf/apple/energy/elastic/_elastic.py
16
17
18
19
20
21
22
23
24
25
26
27
28
@classmethod
def from_actor(
    cls,
    actor: sim.Actor,
    *,
    hess_diag_filter: bool = True,
    hess_quad_filter: bool = True,
) -> Self:
    """Build the energy for ``actor`` with the given Hessian-filter switches.

    Args:
        actor: The actor this energy acts on.
        hess_diag_filter: Whether ``hess_diag`` clamps negative entries.
        hess_quad_filter: Whether ``hess_quad`` clamps negative entries.

    Returns:
        A new instance of ``cls``.
    """
    filters = {
        "hess_diag_filter": hess_diag_filter,
        "hess_quad_filter": hess_quad_filter,
    }
    return cls(actor=actor, **filters)

fun ¤

fun(
    x: ArrayDict, /, params: GlobalParams
) -> Float[Array, ""]
Source code in src/liblaf/apple/energy/elastic/_elastic.py
43
44
45
46
47
48
49
@override
@utils.jit(inline=True)
def fun(self, x: struct.ArrayDict, /, params: sim.GlobalParams) -> Float[Array, ""]:
    """Total elastic energy: the density is integrated per cell, then summed.

    Args:
        x: Per-actor fields; this energy reads ``x[self.actor.id]``.
        params: Global simulation parameters, forwarded to ``energy_density``.

    Returns:
        Scalar energy.
    """
    density: Float[Array, "c q"] = self.energy_density(x[self.actor.id], params)
    per_cell: Float[Array, " c"] = self.region.integrate(density)
    return jnp.sum(per_cell)

fun_and_jac ¤

fun_and_jac(
    x: ArrayDict, /, params: GlobalParams
) -> tuple[Float[Array, ""], ArrayDict]
Source code in src/liblaf/apple/energy/elastic/_elastic.py
92
93
94
95
96
97
@override
@utils.jit(inline=True)
def fun_and_jac(
    self, x: struct.ArrayDict, /, params: sim.GlobalParams
) -> tuple[Float[Array, ""], struct.ArrayDict]:
    """Return ``(fun(x), jac(x))``, computed by two independent calls."""
    value = self.fun(x, params)
    gradient = self.jac(x, params)
    return value, gradient

hess_diag ¤

hess_diag(
    x: ArrayDict, /, params: GlobalParams
) -> ArrayDict
Source code in src/liblaf/apple/energy/elastic/_elastic.py
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
@override
@utils.jit(inline=True)
def hess_diag(
    self, x: struct.ArrayDict, /, params: sim.GlobalParams
) -> struct.ArrayDict:
    """Hessian diagonal, assembled from quadrature points to mesh points.

    Args:
        x: Per-actor fields; this energy reads ``x[self.actor.id]``.
        params: Global simulation parameters, forwarded to the density Hessian.

    Returns:
        ``ArrayDict`` mapping this actor's id to a ``(p, J)`` diagonal.
    """
    actor_id = self.actor.id
    per_quad: Float[Array, "c q a J"] = self.energy_density_hess_diag(
        x[actor_id], params
    )
    if self.hess_diag_filter:
        # Clamp negative entries to zero before integration.
        per_quad = jnp.clip(per_quad, min=0.0)
    per_cell: Float[Array, "c a J"] = self.region.integrate(per_quad)
    per_point: Float[Array, "p J"] = self.region.gather(per_cell)
    return struct.ArrayDict({actor_id: per_point})

hess_quad ¤

hess_quad(
    x: ArrayDict, p: ArrayDict, /, params: GlobalParams
) -> Float[Array, ""]
Source code in src/liblaf/apple/energy/elastic/_elastic.py
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
@override
@utils.jit(inline=True)
def hess_quad(
    self, x: struct.ArrayDict, p: struct.ArrayDict, /, params: sim.GlobalParams
) -> Float[Array, ""]:
    """Scalar Hessian quadratic form along ``p``, integrated over the region.

    Args:
        x: Per-actor fields at which the Hessian is evaluated.
        p: Per-actor direction fields.
        params: Global simulation parameters, forwarded to the density term.

    Returns:
        Scalar sum of the integrated per-cell values.
    """
    actor_id = self.actor.id
    per_quad: Float[Array, "c q"] = self.energy_density_hess_quad(
        x[actor_id], p[actor_id], params
    )
    if self.hess_quad_filter:
        # Clamp negative per-quadrature-point contributions to zero.
        per_quad = jnp.clip(per_quad, min=0.0)
    per_cell: Float[Array, " c"] = self.region.integrate(per_quad)
    return jnp.sum(per_cell)

hessp ¤

hessp(
    x: ArrayDict, p: ArrayDict, /, params: GlobalParams
) -> ArrayDict
Source code in src/liblaf/apple/sim/energy/energy.py
63
64
65
66
67
68
@utils.not_implemented
@utils.jit
def hessp(
    self, x: struct.ArrayDict, p: struct.ArrayDict, /, params: GlobalParams
) -> struct.ArrayDict:
    """Hessian-vector product, expressed as the JVP of ``jac`` at ``x`` along ``p``.

    NOTE(review): decorated with ``utils.not_implemented``; presumably this marks
    the default as a placeholder for subclasses to override. Confirm.
    """
    jac_jvp = math.jvp(self.jac)
    return jac_jvp(x, p, params)

jac ¤

jac(x: ArrayDict, /, params: GlobalParams) -> ArrayDict
Source code in src/liblaf/apple/energy/elastic/_elastic.py
51
52
53
54
55
56
57
58
@override
@utils.jit(inline=True)
def jac(self, x: struct.ArrayDict, /, params: sim.GlobalParams) -> struct.ArrayDict:
    """Energy gradient, assembled from quadrature points to mesh points.

    Args:
        x: Per-actor fields; this energy reads ``x[self.actor.id]``.
        params: Global simulation parameters, forwarded to the density gradient.

    Returns:
        ``ArrayDict`` mapping this actor's id to a ``(p, J)`` gradient.
    """
    actor_id = self.actor.id
    per_quad: Float[Array, "c q a J"] = self.energy_density_jac(x[actor_id], params)
    per_cell: Float[Array, "c a J"] = self.region.integrate(per_quad)
    return struct.ArrayDict({actor_id: self.region.gather(per_cell)})

jac_and_hess_diag ¤

jac_and_hess_diag(
    x: ArrayDict, /, params: GlobalParams
) -> tuple[ArrayDict, ArrayDict]
Source code in src/liblaf/apple/energy/elastic/_elastic.py
 99
100
101
102
103
104
@override
@utils.jit(inline=True)
def jac_and_hess_diag(
    self, x: struct.ArrayDict, /, params: sim.GlobalParams
) -> tuple[struct.ArrayDict, struct.ArrayDict]:
    """Return ``(jac(x), hess_diag(x))``, computed by two independent calls."""
    gradient = self.jac(x, params)
    diagonal = self.hess_diag(x, params)
    return gradient, diagonal

pre_optim_iter ¤

pre_optim_iter(params: GlobalParams) -> Self
Source code in src/liblaf/apple/sim/energy/energy.py
24
25
def pre_optim_iter(self, params: GlobalParams) -> Self:  # noqa: ARG002
    """Hook called before an optimizer iteration; the default does nothing.

    Args:
        params: Global simulation parameters (ignored).

    Returns:
        ``self``, unchanged.
    """
    unchanged = self
    return unchanged

pre_optim_iter_jit deprecated ¤

pre_optim_iter_jit(params: GlobalParams) -> Self
Deprecated

deprecated.

Source code in src/liblaf/apple/sim/energy/energy.py
27
28
29
30
@utils.jit(inline=True, validate=False)
@deprecated("deprecated.")
def pre_optim_iter_jit(self, params: GlobalParams) -> Self:  # noqa: ARG002
    """Deprecated. Like ``pre_optim_iter``, it returns ``self`` unchanged."""
    unchanged = self
    return unchanged

pre_optim_iter_no_jit deprecated ¤

pre_optim_iter_no_jit(params: GlobalParams) -> Self
Deprecated

deprecated.

Source code in src/liblaf/apple/sim/energy/energy.py
32
33
34
@deprecated("deprecated.")
def pre_optim_iter_no_jit(self, params: GlobalParams) -> Self:  # noqa: ARG002
    """Deprecated. Like ``pre_optim_iter``, it returns ``self`` unchanged."""
    unchanged = self
    return unchanged

pre_time_step ¤

pre_time_step(params: GlobalParams) -> Self
Source code in src/liblaf/apple/sim/energy/energy.py
21
22
def pre_time_step(self, params: GlobalParams) -> Self:  # noqa: ARG002
    """Hook called before a time step; the default does nothing.

    Args:
        params: Global simulation parameters (ignored).

    Returns:
        ``self``, unchanged.
    """
    unchanged = self
    return unchanged

replace ¤

replace(**changes: Any) -> Self
Source code in src/liblaf/apple/struct/tree/_pytree.py
19
20
def replace(self, **changes: Any) -> Self:
    return dataclasses.replace(self, **changes)

tree_at ¤

tree_at(
    where: Callable[[Self], Node | Sequence[Node]],
    replace: Any | Sequence[Any] = MISSING,
    replace_fn: Callable[[Node], Any] = MISSING,
    is_leaf: Callable[[Any], bool] | None = None,
) -> Self
Source code in src/liblaf/apple/struct/tree/_pytree.py
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
def tree_at(
    self,
    where: Callable[[Self], Node | Sequence[Node]],
    replace: Any | Sequence[Any] | MISSING = MISSING,
    replace_fn: Callable[[Node], Any] | MISSING = MISSING,
    is_leaf: Callable[[Any], bool] | None = None,
) -> Self:
    """Apply ``equinox.tree_at`` to ``self``, forwarding only the arguments given.

    ``replace`` and ``replace_fn`` count as given when they are not ``MISSING``.
    ``is_leaf`` counts as given when it is not ``None``.
    """
    candidates = {
        "replace": (replace, replace is not MISSING),
        "replace_fn": (replace_fn, replace_fn is not MISSING),
        "is_leaf": (is_leaf, is_leaf is not None),
    }
    kwargs: dict[str, Any] = {
        name: value for name, (value, given) in candidates.items() if given
    }
    return eqx.tree_at(where, self, **kwargs)

with_actors ¤

with_actors(actors: NodeContainer[Actor]) -> Self
Source code in src/liblaf/apple/energy/elastic/_elastic.py
35
36
37
@override
def with_actors(self, actors: struct.NodeContainer[sim.Actor]) -> Self:
    """Rebind ``actor`` to the entry in ``actors`` that has the same id."""
    refreshed = actors[self.actor.id]
    return self.replace(actor=refreshed)

ArapActive ¤

Bases: Elastic

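As-Rigid-As-Possible with activation: the energy kernels additionally take the per-cell `activation` tensor and `muscle_fraction`.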

\[ \Psi = \frac{\mu}{2} \|F - R\|_F^2 = \frac{\mu}{2} (I_2 - 2 I_1 + 3) \]

Parameters:

  • id (str, default: None ) –
  • actor (Actor) –
  • hess_diag_filter (bool, default: True ) –
  • hess_quad_filter (bool, default: True ) –

Methods:

Attributes:

__dataclass_fields__ class-attribute ¤

__dataclass_fields__: dict[str, Field[Any]]

activation property ¤

activation: Float[Array, 'c J J']

actor class-attribute instance-attribute ¤

actor: Actor = field()

actors property ¤

hess_diag_filter class-attribute instance-attribute ¤

hess_diag_filter: bool = field(default=True, kw_only=True)

hess_quad_filter class-attribute instance-attribute ¤

hess_quad_filter: bool = field(default=True, kw_only=True)

id class-attribute instance-attribute ¤

id: str = field(default=None, kw_only=True)

mu property ¤

mu: Float[Array, ' c']

muscle_fraction property ¤

muscle_fraction: Float[Array, ' c']

region property ¤

region: Region

__post_init__ ¤

__post_init__() -> None
Source code in src/liblaf/apple/struct/tree/_node.py
13
14
15
def __post_init__(self) -> None:
    """Fill in ``id`` with ``uniq_id(self)`` when it was left as ``None``."""
    if self.id is not None:
        return
    # ``object.__setattr__`` bypasses the dataclass ``__setattr__``, so the
    # assignment also works if the dataclass is frozen.
    object.__setattr__(self, "id", uniq_id(self))

energy_density ¤

energy_density(
    field: Field, /, params: GlobalParams
) -> Float[Array, "c q"]
Source code in src/liblaf/apple/energy/elastic/arap_active/_arap_active.py
32
33
34
35
36
37
38
39
40
41
42
43
44
45
@override
@utils.jit(inline=True)
def energy_density(
    self, field: Field, /, params: sim.GlobalParams
) -> Float[jax.Array, "c q"]:
    """Active ARAP energy density at every quadrature point of every cell.

    Args:
        field: The actor's field; the deformation gradient is computed from it.
        params: Global simulation parameters (not used by this energy).

    Returns:
        Energy density with shape ``(c, q)``.
    """
    region: sim.Region = self.region
    # The kernel consumes a flattened ``cq`` batch axis.
    F_cq: Float[jax.Array, "cq J J"] = region.squeeze_cq(
        region.deformation_gradient(field)
    )
    psi_cq: Float[jax.Array, " cq"]
    (psi_cq,) = kernel.arap_active_energy_density_kernel(
        F_cq, self.activation, self.mu, self.muscle_fraction
    )
    return region.unsqueeze_cq(psi_cq)

energy_density_hess_diag ¤

energy_density_hess_diag(
    field: Field, /, params: GlobalParams
) -> Float[Array, "c q a J"]
Source code in src/liblaf/apple/energy/elastic/arap_active/_arap_active.py
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
@override
@utils.jit(inline=True)
def energy_density_hess_diag(
    self, field: Field, /, params: sim.GlobalParams
) -> Float[jax.Array, "c q a J"]:
    """Per-quadrature-point diagonal of the active ARAP energy-density Hessian.

    Args:
        field: The actor's field; the deformation gradient is computed from it.
        params: Global simulation parameters (not used by this energy).

    Returns:
        Hessian diagonal with shape ``(c, q, a, J)``. ``a`` is the per-cell
        point axis of ``region.dhdX``.
    """
    region: sim.Region = self.region
    # The stale ``hess_diag: Float[jax.Array, "cells 4 3"]`` pre-declaration was
    # removed: it contradicted the real ``cq a J`` / ``c q a J`` shapes below.
    F: Float[jax.Array, "c q J J"] = region.deformation_gradient(field)
    F: Float[jax.Array, "cq J J"] = region.squeeze_cq(F)
    dhdX: Float[jax.Array, "cq a J"] = region.squeeze_cq(region.dhdX)
    hess_diag: Float[jax.Array, "cq a J"]
    (hess_diag,) = kernel.arap_active_energy_density_hess_diag_kernel(
        F, dhdX, self.activation, self.mu, self.muscle_fraction
    )
    hess_diag: Float[jax.Array, "c q a J"] = region.unsqueeze_cq(hess_diag)
    return hess_diag

energy_density_hess_quad ¤

energy_density_hess_quad(
    field: Field, p: Field, /, params: GlobalParams
) -> Float[Array, "c q"]
Source code in src/liblaf/apple/energy/elastic/arap_active/_arap_active.py
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
@override
@utils.jit(inline=True)
def energy_density_hess_quad(
    self, field: Field, p: Field, /, params: sim.GlobalParams
) -> Float[jax.Array, "c q"]:
    """Per-quadrature-point Hessian quadratic form of the active ARAP density along ``p``.

    Args:
        field: The actor's field at which the Hessian is evaluated.
        p: Direction field; it is scattered to cells before the kernel call.
        params: Global simulation parameters (not used by this energy).

    Returns:
        Quadratic-form values with shape ``(c, q)``.
    """
    region: sim.Region = self.region
    F_cq: Float[jax.Array, "cq J J"] = region.squeeze_cq(
        region.deformation_gradient(field)
    )
    dhdX_cq: Float[jax.Array, "cq a J"] = region.squeeze_cq(region.dhdX)
    p_cells = region.scatter(p)
    quad_cq: Float[jax.Array, " cq"]
    (quad_cq,) = kernel.arap_active_energy_density_hess_quad_kernel(
        F_cq, p_cells, dhdX_cq, self.activation, self.mu, self.muscle_fraction
    )
    return region.unsqueeze_cq(quad_cq)

energy_density_jac ¤

energy_density_jac(
    field: Field, /, params: GlobalParams
) -> Float[Array, "c q a J"]
Source code in src/liblaf/apple/energy/elastic/_elastic.py
116
117
118
119
120
121
122
@utils.jit(inline=True)
def energy_density_jac(
    self, field: Field, /, params: sim.GlobalParams
) -> Float[Array, "c q a J"]:
    """Energy-density gradient obtained by pulling the PK1 stress back through the region.

    Args:
        field: The actor's field.
        params: Global simulation parameters, forwarded to the stress.

    Returns:
        Gradient with shape ``(c, q, a, J)``.
    """
    stress: Float[Array, "c q J J"] = self.first_piola_kirchhoff_stress(field, params)
    return self.region.gradient_vjp(stress)

first_piola_kirchhoff_stress ¤

first_piola_kirchhoff_stress(
    field: Field, /, params: GlobalParams
) -> Float[Array, "c q J J"]
Source code in src/liblaf/apple/energy/elastic/arap_active/_arap_active.py
47
48
49
50
51
52
53
54
55
56
57
58
59
60
@override
@utils.jit(inline=True)
def first_piola_kirchhoff_stress(
    self, field: Field, /, params: sim.GlobalParams
) -> Float[jax.Array, "c q J J"]:
    """First Piola-Kirchhoff stress of the active ARAP model at every quadrature point.

    Args:
        field: The actor's field; the deformation gradient is computed from it.
        params: Global simulation parameters (not used by this energy).

    Returns:
        Stress tensors with shape ``(c, q, J, J)``.
    """
    region: sim.Region = self.region
    F_cq: Float[jax.Array, "cq J J"] = region.squeeze_cq(
        region.deformation_gradient(field)
    )
    stress_cq: Float[jax.Array, "cq J J"]
    (stress_cq,) = kernel.arap_active_first_piola_kirchhoff_stress_kernel(
        F_cq, self.activation, self.mu, self.muscle_fraction
    )
    return region.unsqueeze_cq(stress_cq)

from_actor classmethod ¤

from_actor(
    actor: Actor,
    *,
    hess_diag_filter: bool = True,
    hess_quad_filter: bool = True,
) -> Self
Source code in src/liblaf/apple/energy/elastic/_elastic.py
16
17
18
19
20
21
22
23
24
25
26
27
28
@classmethod
def from_actor(
    cls,
    actor: sim.Actor,
    *,
    hess_diag_filter: bool = True,
    hess_quad_filter: bool = True,
) -> Self:
    """Build the energy for ``actor`` with the given Hessian-filter switches.

    Args:
        actor: The actor this energy acts on.
        hess_diag_filter: Whether ``hess_diag`` clamps negative entries.
        hess_quad_filter: Whether ``hess_quad`` clamps negative entries.

    Returns:
        A new instance of ``cls``.
    """
    filters = {
        "hess_diag_filter": hess_diag_filter,
        "hess_quad_filter": hess_quad_filter,
    }
    return cls(actor=actor, **filters)

fun ¤

fun(
    x: ArrayDict, /, params: GlobalParams
) -> Float[Array, ""]
Source code in src/liblaf/apple/energy/elastic/_elastic.py
43
44
45
46
47
48
49
@override
@utils.jit(inline=True)
def fun(self, x: struct.ArrayDict, /, params: sim.GlobalParams) -> Float[Array, ""]:
    """Total elastic energy: the density is integrated per cell, then summed.

    Args:
        x: Per-actor fields; this energy reads ``x[self.actor.id]``.
        params: Global simulation parameters, forwarded to ``energy_density``.

    Returns:
        Scalar energy.
    """
    density: Float[Array, "c q"] = self.energy_density(x[self.actor.id], params)
    per_cell: Float[Array, " c"] = self.region.integrate(density)
    return jnp.sum(per_cell)

fun_and_jac ¤

fun_and_jac(
    x: ArrayDict, /, params: GlobalParams
) -> tuple[Float[Array, ""], ArrayDict]
Source code in src/liblaf/apple/energy/elastic/_elastic.py
92
93
94
95
96
97
@override
@utils.jit(inline=True)
def fun_and_jac(
    self, x: struct.ArrayDict, /, params: sim.GlobalParams
) -> tuple[Float[Array, ""], struct.ArrayDict]:
    """Return ``(fun(x), jac(x))``, computed by two independent calls."""
    value = self.fun(x, params)
    gradient = self.jac(x, params)
    return value, gradient

hess_diag ¤

hess_diag(
    x: ArrayDict, /, params: GlobalParams
) -> ArrayDict
Source code in src/liblaf/apple/energy/elastic/_elastic.py
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
@override
@utils.jit(inline=True)
def hess_diag(
    self, x: struct.ArrayDict, /, params: sim.GlobalParams
) -> struct.ArrayDict:
    """Hessian diagonal, assembled from quadrature points to mesh points.

    Args:
        x: Per-actor fields; this energy reads ``x[self.actor.id]``.
        params: Global simulation parameters, forwarded to the density Hessian.

    Returns:
        ``ArrayDict`` mapping this actor's id to a ``(p, J)`` diagonal.
    """
    actor_id = self.actor.id
    per_quad: Float[Array, "c q a J"] = self.energy_density_hess_diag(
        x[actor_id], params
    )
    if self.hess_diag_filter:
        # Clamp negative entries to zero before integration.
        per_quad = jnp.clip(per_quad, min=0.0)
    per_cell: Float[Array, "c a J"] = self.region.integrate(per_quad)
    per_point: Float[Array, "p J"] = self.region.gather(per_cell)
    return struct.ArrayDict({actor_id: per_point})

hess_quad ¤

hess_quad(
    x: ArrayDict, p: ArrayDict, /, params: GlobalParams
) -> Float[Array, ""]
Source code in src/liblaf/apple/energy/elastic/_elastic.py
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
@override
@utils.jit(inline=True)
def hess_quad(
    self, x: struct.ArrayDict, p: struct.ArrayDict, /, params: sim.GlobalParams
) -> Float[Array, ""]:
    """Scalar Hessian quadratic form along ``p``, integrated over the region.

    Args:
        x: Per-actor fields at which the Hessian is evaluated.
        p: Per-actor direction fields.
        params: Global simulation parameters, forwarded to the density term.

    Returns:
        Scalar sum of the integrated per-cell values.
    """
    actor_id = self.actor.id
    per_quad: Float[Array, "c q"] = self.energy_density_hess_quad(
        x[actor_id], p[actor_id], params
    )
    if self.hess_quad_filter:
        # Clamp negative per-quadrature-point contributions to zero.
        per_quad = jnp.clip(per_quad, min=0.0)
    per_cell: Float[Array, " c"] = self.region.integrate(per_quad)
    return jnp.sum(per_cell)

hessp ¤

hessp(
    x: ArrayDict, p: ArrayDict, /, params: GlobalParams
) -> ArrayDict
Source code in src/liblaf/apple/sim/energy/energy.py
63
64
65
66
67
68
@utils.not_implemented
@utils.jit
def hessp(
    self, x: struct.ArrayDict, p: struct.ArrayDict, /, params: GlobalParams
) -> struct.ArrayDict:
    """Hessian-vector product, expressed as the JVP of ``jac`` at ``x`` along ``p``.

    NOTE(review): decorated with ``utils.not_implemented``; presumably this marks
    the default as a placeholder for subclasses to override. Confirm.
    """
    jac_jvp = math.jvp(self.jac)
    return jac_jvp(x, p, params)

jac ¤

jac(x: ArrayDict, /, params: GlobalParams) -> ArrayDict
Source code in src/liblaf/apple/energy/elastic/_elastic.py
51
52
53
54
55
56
57
58
@override
@utils.jit(inline=True)
def jac(self, x: struct.ArrayDict, /, params: sim.GlobalParams) -> struct.ArrayDict:
    """Energy gradient, assembled from quadrature points to mesh points.

    Args:
        x: Per-actor fields; this energy reads ``x[self.actor.id]``.
        params: Global simulation parameters, forwarded to the density gradient.

    Returns:
        ``ArrayDict`` mapping this actor's id to a ``(p, J)`` gradient.
    """
    actor_id = self.actor.id
    per_quad: Float[Array, "c q a J"] = self.energy_density_jac(x[actor_id], params)
    per_cell: Float[Array, "c a J"] = self.region.integrate(per_quad)
    return struct.ArrayDict({actor_id: self.region.gather(per_cell)})

jac_and_hess_diag ¤

jac_and_hess_diag(
    x: ArrayDict, /, params: GlobalParams
) -> tuple[ArrayDict, ArrayDict]
Source code in src/liblaf/apple/energy/elastic/_elastic.py
 99
100
101
102
103
104
@override
@utils.jit(inline=True)
def jac_and_hess_diag(
    self, x: struct.ArrayDict, /, params: sim.GlobalParams
) -> tuple[struct.ArrayDict, struct.ArrayDict]:
    """Return ``(jac(x), hess_diag(x))``, computed by two independent calls."""
    gradient = self.jac(x, params)
    diagonal = self.hess_diag(x, params)
    return gradient, diagonal

pre_optim_iter ¤

pre_optim_iter(params: GlobalParams) -> Self
Source code in src/liblaf/apple/sim/energy/energy.py
24
25
def pre_optim_iter(self, params: GlobalParams) -> Self:  # noqa: ARG002
    """Hook called before an optimizer iteration; the default does nothing.

    Args:
        params: Global simulation parameters (ignored).

    Returns:
        ``self``, unchanged.
    """
    unchanged = self
    return unchanged

pre_optim_iter_jit deprecated ¤

pre_optim_iter_jit(params: GlobalParams) -> Self
Deprecated

deprecated.

Source code in src/liblaf/apple/sim/energy/energy.py
27
28
29
30
@utils.jit(inline=True, validate=False)
@deprecated("deprecated.")
def pre_optim_iter_jit(self, params: GlobalParams) -> Self:  # noqa: ARG002
    """Deprecated. Like ``pre_optim_iter``, it returns ``self`` unchanged."""
    unchanged = self
    return unchanged

pre_optim_iter_no_jit deprecated ¤

pre_optim_iter_no_jit(params: GlobalParams) -> Self
Deprecated

deprecated.

Source code in src/liblaf/apple/sim/energy/energy.py
32
33
34
@deprecated("deprecated.")
def pre_optim_iter_no_jit(self, params: GlobalParams) -> Self:  # noqa: ARG002
    """Deprecated. Like ``pre_optim_iter``, it returns ``self`` unchanged."""
    unchanged = self
    return unchanged

pre_time_step ¤

pre_time_step(params: GlobalParams) -> Self
Source code in src/liblaf/apple/sim/energy/energy.py
21
22
def pre_time_step(self, params: GlobalParams) -> Self:  # noqa: ARG002
    """Hook called before a time step; the default does nothing.

    Args:
        params: Global simulation parameters (ignored).

    Returns:
        ``self``, unchanged.
    """
    unchanged = self
    return unchanged

replace ¤

replace(**changes: Any) -> Self
Source code in src/liblaf/apple/struct/tree/_pytree.py
19
20
def replace(self, **changes: Any) -> Self:
    return dataclasses.replace(self, **changes)

tree_at ¤

tree_at(
    where: Callable[[Self], Node | Sequence[Node]],
    replace: Any | Sequence[Any] = MISSING,
    replace_fn: Callable[[Node], Any] = MISSING,
    is_leaf: Callable[[Any], bool] | None = None,
) -> Self
Source code in src/liblaf/apple/struct/tree/_pytree.py
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
def tree_at(
    self,
    where: Callable[[Self], Node | Sequence[Node]],
    replace: Any | Sequence[Any] | MISSING = MISSING,
    replace_fn: Callable[[Node], Any] | MISSING = MISSING,
    is_leaf: Callable[[Any], bool] | None = None,
) -> Self:
    """Apply ``equinox.tree_at`` to ``self``, forwarding only the arguments given.

    ``replace`` and ``replace_fn`` count as given when they are not ``MISSING``.
    ``is_leaf`` counts as given when it is not ``None``.
    """
    candidates = {
        "replace": (replace, replace is not MISSING),
        "replace_fn": (replace_fn, replace_fn is not MISSING),
        "is_leaf": (is_leaf, is_leaf is not None),
    }
    kwargs: dict[str, Any] = {
        name: value for name, (value, given) in candidates.items() if given
    }
    return eqx.tree_at(where, self, **kwargs)

with_actors ¤

with_actors(actors: NodeContainer[Actor]) -> Self
Source code in src/liblaf/apple/energy/elastic/_elastic.py
35
36
37
@override
def with_actors(self, actors: struct.NodeContainer[sim.Actor]) -> Self:
    return self.replace(actor=actors[self.actor.id])

CollisionCandidatesVertFace ¤

Bases: PyTree

CollisionCandidatesVertFace(closest: jaxtyping.Float[Array, 'points 3'] = None, collide: jaxtyping.Bool[Array, 'points'] = None, distance: jaxtyping.Float[Array, 'points'] = None, face_id: jaxtyping.Integer[Array, 'points'] = None, face_normal: jaxtyping.Float[Array, 'points 3'] = None, uv: jaxtyping.Float[Array, 'points 2'] = None)

Parameters:

  • closest (Float[Array, 'points 3'], default: None ) –
  • collide (Bool[Array, points], default: None ) –
  • distance (Float[Array, points], default: None ) –
  • face_id (Integer[Array, points], default: None ) –
  • face_normal (Float[Array, 'points 3'], default: None ) –
  • uv (Float[Array, 'points 2'], default: None ) –

Methods:

Attributes:

__dataclass_fields__ class-attribute ¤

__dataclass_fields__: dict[str, Field[Any]]

closest class-attribute instance-attribute ¤

closest: Float[Array, 'points 3'] = array(default=None)

collide class-attribute instance-attribute ¤

collide: Bool[Array, ' points'] = array(default=None)

distance class-attribute instance-attribute ¤

distance: Float[Array, ' points'] = array(default=None)

face_id class-attribute instance-attribute ¤

face_id: Integer[Array, ' points'] = array(default=None)

face_normal class-attribute instance-attribute ¤

face_normal: Float[Array, 'points 3'] = array(default=None)

uv class-attribute instance-attribute ¤

uv: Float[Array, 'points 2'] = array(default=None)

replace ¤

replace(**changes: Any) -> Self
Source code in src/liblaf/apple/struct/tree/_pytree.py
19
20
def replace(self, **changes: Any) -> Self:
    return dataclasses.replace(self, **changes)

tree_at ¤

tree_at(
    where: Callable[[Self], Node | Sequence[Node]],
    replace: Any | Sequence[Any] = MISSING,
    replace_fn: Callable[[Node], Any] = MISSING,
    is_leaf: Callable[[Any], bool] | None = None,
) -> Self
Source code in src/liblaf/apple/struct/tree/_pytree.py
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
def tree_at(
    self,
    where: Callable[[Self], Node | Sequence[Node]],
    replace: Any | Sequence[Any] | MISSING = MISSING,
    replace_fn: Callable[[Node], Any] | MISSING = MISSING,
    is_leaf: Callable[[Any], bool] | None = None,
) -> Self:
    """Apply ``equinox.tree_at`` to ``self``, forwarding only the arguments given.

    ``replace`` and ``replace_fn`` count as given when they are not ``MISSING``.
    ``is_leaf`` counts as given when it is not ``None``.
    """
    candidates = {
        "replace": (replace, replace is not MISSING),
        "replace_fn": (replace_fn, replace_fn is not MISSING),
        "is_leaf": (is_leaf, is_leaf is not None),
    }
    kwargs: dict[str, Any] = {
        name: value for name, (value, given) in candidates.items() if given
    }
    return eqx.tree_at(where, self, **kwargs)

CollisionVertFace ¤

Bases: Energy

CollisionVertFace(rigid: liblaf.apple.sim.actor.actor.Actor = None, soft: liblaf.apple.sim.actor.actor.Actor = None, stiffness: jaxtyping.Float[Array, ''] = 100000.0, rest_length: jaxtyping.Float[Array, ''] = 0.001, max_dist: jaxtyping.Float[Array, ''] = 0.01, epsilon: jaxtyping.Float[Array, ''] = 0.001, candidates: liblaf.apple.energy.collision.vert_face.vert_face.CollisionCandidatesVertFace = <dynamic>, *, id: str = None, filter_hess_diag: bool = True, filter_hess_quad: bool = True)

Parameters:

  • id (str, default: None ) –
  • rigid (Actor, default: None ) –
  • soft (Actor, default: None ) –
  • stiffness (Float[Array, ''], default: 100000.0 ) –
  • rest_length (Float[Array, ''], default: 0.001 ) –
  • max_dist (Float[Array, ''], default: 0.01 ) –
  • epsilon (Float[Array, ''], default: 0.001 ) –
  • filter_hess_diag (bool, default: True ) –
  • filter_hess_quad (bool, default: True ) –
  • candidates (CollisionCandidatesVertFace, default: <dynamic> ) –

    CollisionCandidatesVertFace(closest: jaxtyping.Float[Array, 'points 3'] = None, collide: jaxtyping.Bool[Array, 'points'] = None, distance: jaxtyping.Float[Array, 'points'] = None, face_id: jaxtyping.Integer[Array, 'points'] = None, face_normal: jaxtyping.Float[Array, 'points 3'] = None, uv: jaxtyping.Float[Array, 'points 2'] = None)

Methods:

Attributes:

__dataclass_fields__ class-attribute ¤

__dataclass_fields__: dict[str, Field[Any]]

actors property ¤

candidates class-attribute instance-attribute ¤

epsilon class-attribute instance-attribute ¤

epsilon: Float[Array, ''] = array(default=0.001)

filter_hess_diag class-attribute instance-attribute ¤

filter_hess_diag: bool = field(default=True, kw_only=True)

filter_hess_quad class-attribute instance-attribute ¤

filter_hess_quad: bool = field(default=True, kw_only=True)

id class-attribute instance-attribute ¤

id: str = field(default=None, kw_only=True)

max_dist class-attribute instance-attribute ¤

max_dist: Float[Array, ''] = array(default=0.01)

rest_length class-attribute instance-attribute ¤

rest_length: Float[Array, ''] = array(default=0.001)

rigid class-attribute instance-attribute ¤

rigid: Actor = field(default=None)

soft class-attribute instance-attribute ¤

soft: Actor = field(default=None)

stiffness class-attribute instance-attribute ¤

stiffness: Float[Array, ''] = array(default=100000.0)

__post_init__ ¤

__post_init__() -> None
Source code in src/liblaf/apple/struct/tree/_node.py
13
14
15
def __post_init__(self) -> None:
    """Fill in ``id`` with ``uniq_id(self)`` when it was left as ``None``."""
    if self.id is not None:
        return
    # ``object.__setattr__`` bypasses the dataclass ``__setattr__``, so the
    # assignment also works if the dataclass is frozen.
    object.__setattr__(self, "id", uniq_id(self))

collide ¤

Source code in src/liblaf/apple/energy/collision/vert_face/vert_face.py
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
@utils.jit(inline=True)
def collide(self) -> CollisionCandidatesVertFace:
    """Detect vertex-face collision candidates of ``soft`` against ``rigid``.

    Launches ``collision_detect_vert_face_kernel`` once per soft point against
    the rigid actor's collision mesh. ``rest_length``, ``max_dist`` and
    ``epsilon`` are passed as 1-element arrays.

    Returns:
        The six per-point kernel outputs packed into a
        ``CollisionCandidatesVertFace`` (same field names as the outputs).
    """
    field_names = ("closest", "collide", "distance", "face_id", "face_normal", "uv")
    n_points = self.soft.n_points
    outputs = collision_detect_vert_face_kernel(
        self.soft.positions,
        np.uint64(self.rigid.collision_mesh.id),
        self.rest_length.reshape((1,)),
        self.max_dist.reshape((1,)),
        self.epsilon.reshape((1,)),
        output_dims={name: (n_points,) for name in field_names},
        launch_dims=(n_points,),
    )
    by_name = dict(zip(field_names, outputs, strict=True))
    return CollisionCandidatesVertFace(**by_name)

from_actors classmethod ¤

from_actors(
    rigid: Actor,
    soft: Actor,
    *,
    stiffness: float = 1000.0,
    rest_length: float = 0.001,
    max_dist: float | None = None,
    epsilon: float = 0.001,
) -> Self
Source code in src/liblaf/apple/energy/collision/vert_face/vert_face.py
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
@classmethod
def from_actors(
    cls,
    rigid: sim.Actor,
    soft: sim.Actor,
    *,
    stiffness: float = 1e3,
    rest_length: float = 1e-3,
    max_dist: float | None = None,
    epsilon: float = 1e-3,
) -> Self:
    """Build a vertex-face collision energy between ``rigid`` and ``soft``.

    Args:
        rigid: Actor providing the collision mesh.
        soft: Actor whose points are tested against that mesh.
        stiffness: Penalty stiffness. NOTE: this default (1e3) differs from
            the dataclass field default (1e5).
        rest_length: Rest length passed to the collision kernels.
        max_dist: Detection distance; defaults to ``2 * rest_length``.
        epsilon: Tolerance passed to the detection kernel.

    Returns:
        A new instance with all scalars converted via ``jnp.asarray``.
    """
    resolved_max_dist = 2.0 * rest_length if max_dist is None else max_dist
    return cls(
        rigid=rigid,
        soft=soft,
        stiffness=jnp.asarray(stiffness),
        rest_length=jnp.asarray(rest_length),
        max_dist=jnp.asarray(resolved_max_dist),
        epsilon=jnp.asarray(epsilon),
    )

fun ¤

fun(
    x: ArrayDict, /, params: GlobalParams
) -> Float[Array, ""]
Source code in src/liblaf/apple/energy/collision/vert_face/vert_face.py
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
@override
@utils.jit(inline=True)
def fun(self, x: struct.ArrayDict, /, params: sim.GlobalParams) -> Float[Array, ""]:
    """Collision penalty energy of the soft actor against the stored candidates.

    Args:
        x: Per-actor fields; ``x[self.soft.id]`` is added to ``self.soft.points``.
        params: Global simulation parameters (not used here).

    Returns:
        Scalar sum of the kernel's ``energy`` output.
    """
    deformed: Float[Array, "points dim"] = self.soft.points + x[self.soft.id]
    cand = self.candidates
    # NOTE(review): ``energy`` is requested with shape ``(1,)`` although the
    # kernel is launched once per point, and ``hess_diag`` requests
    # ``(n_points,)``. Presumably the kernel accumulates into index 0; confirm
    # against the kernel before changing either side.
    (energy,) = collision_energy_vert_face_fun_kernel(
        deformed,
        cand.closest,
        cand.collide,
        cand.distance,
        self.rest_length.reshape((1,)),
        self.stiffness.reshape((1,)),
        output_dims={"energy": (1,)},
        launch_dims=(self.soft.n_points,),
    )
    return energy.sum()

fun_and_jac ¤

fun_and_jac(
    x: ArrayDict, /, params: GlobalParams
) -> tuple[Float[Array, ""], ArrayDict]
Source code in src/liblaf/apple/sim/energy/energy.py
87
88
89
90
91
92
@utils.not_implemented
@utils.jit
def fun_and_jac(
    self, x: struct.ArrayDict, /, params: GlobalParams
) -> tuple[Float[Array, ""], struct.ArrayDict]:
    """Return ``(fun(x), jac(x))``, evaluating ``fun`` first and then ``jac``."""
    value = self.fun(x, params)
    gradient = self.jac(x, params)
    return value, gradient

hess_diag ¤

hess_diag(
    x: ArrayDict, /, params: GlobalParams
) -> ArrayDict
Source code in src/liblaf/apple/energy/collision/vert_face/vert_face.py
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
@override
@utils.jit(inline=True)
def hess_diag(
    self, x: struct.ArrayDict, /, params: sim.GlobalParams
) -> struct.ArrayDict:
    """Diagonal of the collision-energy Hessian.

    The soft actor's diagonal comes from
    ``collision_energy_vert_face_hess_diag_kernel`` (one thread per soft
    point). It is clamped at zero when ``self.filter_hess_diag`` is set. The
    rigid actor receives zeros shaped like ``x[self.rigid.id]``.
    """
    n_points = self.soft.n_points
    deformed: Float[Array, "points dim"] = self.soft.points + x[self.soft.id]
    cand = self.candidates
    (diag,) = collision_energy_vert_face_hess_diag_kernel(
        deformed,
        cand.closest,
        cand.collide,
        cand.distance,
        self.rest_length.reshape((1,)),
        self.stiffness.reshape((1,)),
        output_dims={"hess_diag": (n_points,)},
        launch_dims=(n_points,),
    )
    if self.filter_hess_diag:
        # Keep only non-negative diagonal entries.
        diag = jnp.clip(diag, min=0.0)
    rigid_diag = jnp.zeros_like(x[self.rigid.id])
    return struct.ArrayDict({self.soft.id: diag, self.rigid.id: rigid_diag})

hess_quad ¤

hess_quad(
    x: ArrayDict, p: ArrayDict, /, params: GlobalParams
) -> Float[Array, ""]
Source code in src/liblaf/apple/energy/collision/vert_face/vert_face.py
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
@override
@utils.jit(inline=True)
def hess_quad(
    self, x: struct.ArrayDict, p: struct.ArrayDict, /, params: sim.GlobalParams
) -> Float[Array, ""]:
    """Hessian quadratic-form term of the collision energy along ``p``.

    Only the soft actor's entries of ``x`` and ``p`` are read. Per-point
    values from ``collision_energy_vert_face_hess_quad_kernel`` are clamped
    at zero when ``self.filter_hess_quad`` is set, then summed.
    """
    n_points = self.soft.n_points
    deformed: Float[Array, "points dim"] = self.soft.points + x[self.soft.id]
    cand = self.candidates
    (per_point,) = collision_energy_vert_face_hess_quad_kernel(
        deformed,
        p[self.soft.id],
        cand.closest,
        cand.collide,
        cand.distance,
        self.rest_length.reshape((1,)),
        self.stiffness.reshape((1,)),
        output_dims={"hess_quad": (n_points,)},
        launch_dims=(n_points,),
    )
    if self.filter_hess_quad:
        per_point = jnp.clip(per_point, min=0.0)
    return per_point.sum()

hessp ¤

hessp(
    x: ArrayDict, p: ArrayDict, /, params: GlobalParams
) -> ArrayDict
Source code in src/liblaf/apple/sim/energy/energy.py
63
64
65
66
67
68
@utils.not_implemented
@utils.jit
def hessp(
    self, x: struct.ArrayDict, p: struct.ArrayDict, /, params: GlobalParams
) -> struct.ArrayDict:
    """Hessian-vector product obtained as ``math.jvp`` of ``self.jac``.

    Args:
        x: Point at which the Hessian is taken.
        p: Direction vector.
        params: Forwarded to ``jac``.
    """
    jac_jvp = math.jvp(self.jac)
    return jac_jvp(x, p, params)

jac ¤

jac(x: ArrayDict, /, params: GlobalParams) -> ArrayDict
Source code in src/liblaf/apple/energy/collision/vert_face/vert_face.py
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
@override
@utils.jit(inline=True)
def jac(self, x: struct.ArrayDict, /, params: sim.GlobalParams) -> struct.ArrayDict:
    """Gradient of the collision energy.

    The soft actor's gradient comes from
    ``collision_energy_vert_face_jac_kernel`` (one thread per soft point).
    The rigid actor receives zeros shaped like ``x[self.rigid.id]``.
    """
    n_points = self.soft.n_points
    deformed: Float[Array, "points dim"] = self.soft.points + x[self.soft.id]
    cand = self.candidates
    (grad_soft,) = collision_energy_vert_face_jac_kernel(
        deformed,
        cand.closest,
        cand.collide,
        cand.distance,
        self.rest_length.reshape((1,)),
        self.stiffness.reshape((1,)),
        output_dims={"jac": (n_points,)},
        launch_dims=(n_points,),
    )
    grad_rigid = jnp.zeros_like(x[self.rigid.id])
    return struct.ArrayDict({self.soft.id: grad_soft, self.rigid.id: grad_rigid})

jac_and_hess_diag ¤

jac_and_hess_diag(
    x: ArrayDict, /, params: GlobalParams
) -> tuple[ArrayDict, ArrayDict]
Source code in src/liblaf/apple/sim/energy/energy.py
94
95
96
97
98
99
@utils.not_implemented
@utils.jit
def jac_and_hess_diag(
    self, x: struct.ArrayDict, /, params: GlobalParams
) -> tuple[struct.ArrayDict, struct.ArrayDict]:
    """Return ``(jac(x), hess_diag(x))``, evaluating the gradient first."""
    gradient = self.jac(x, params)
    diagonal = self.hess_diag(x, params)
    return gradient, diagonal

pre_optim_iter ¤

pre_optim_iter(params: GlobalParams) -> Self
Source code in src/liblaf/apple/energy/collision/vert_face/vert_face.py
73
74
75
76
@override
def pre_optim_iter(self, params: sim.GlobalParams) -> Self:
    candidates: CollisionCandidatesVertFace = self.collide()
    return self.replace(candidates=candidates)

pre_optim_iter_jit deprecated ¤

pre_optim_iter_jit(params: GlobalParams) -> Self
Deprecated

deprecated.

Source code in src/liblaf/apple/sim/energy/energy.py
27
28
29
30
@utils.jit(inline=True, validate=False)
@deprecated("deprecated.")
def pre_optim_iter_jit(self, params: GlobalParams) -> Self:
    """Deprecated no-op hook.

    Args:
        params: Ignored.

    Returns:
        ``self``, unchanged.
    """
    del params  # intentionally unused
    return self

pre_optim_iter_no_jit deprecated ¤

pre_optim_iter_no_jit(params: GlobalParams) -> Self
Deprecated

deprecated.

Source code in src/liblaf/apple/sim/energy/energy.py
32
33
34
@deprecated("deprecated.")
def pre_optim_iter_no_jit(self, params: GlobalParams) -> Self:
    """Deprecated no-op hook (non-jitted variant).

    Args:
        params: Ignored.

    Returns:
        ``self``, unchanged.
    """
    del params  # intentionally unused
    return self

pre_time_step ¤

pre_time_step(params: GlobalParams) -> Self
Source code in src/liblaf/apple/sim/energy/energy.py
21
22
def pre_time_step(self, params: GlobalParams) -> Self:
    """Per-time-step hook; the base implementation is a no-op.

    Args:
        params: Ignored.

    Returns:
        ``self``, unchanged.
    """
    del params  # unused by the base hook
    return self

replace ¤

replace(**changes: Any) -> Self
Source code in src/liblaf/apple/struct/tree/_pytree.py
19
20
def replace(self, **changes: Any) -> Self:
    """Return a copy of this dataclass with ``changes`` applied.

    Delegates to :func:`dataclasses.replace`; ``self`` is left untouched.
    """
    updated = dataclasses.replace(self, **changes)
    return updated

tree_at ¤

tree_at(
    where: Callable[[Self], Node | Sequence[Node]],
    replace: Any | Sequence[Any] = MISSING,
    replace_fn: Callable[[Node], Any] = MISSING,
    is_leaf: Callable[[Any], bool] | None = None,
) -> Self
Source code in src/liblaf/apple/struct/tree/_pytree.py
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
def tree_at(
    self,
    where: Callable[[Self], Node | Sequence[Node]],
    replace: Any | Sequence[Any] | MISSING = MISSING,
    replace_fn: Callable[[Node], Any] | MISSING = MISSING,
    is_leaf: Callable[[Any], bool] | None = None,
) -> Self:
    """Functional update of this pytree via ``eqx.tree_at``.

    Only options the caller actually supplied are forwarded. ``replace`` and
    ``replace_fn`` are omitted when left as ``MISSING``; ``is_leaf`` is
    omitted when ``None``.
    """
    options: dict[str, Any] = {
        key: value
        for key, value in (("replace", replace), ("replace_fn", replace_fn))
        if value is not MISSING
    }
    if is_leaf is not None:
        options["is_leaf"] = is_leaf
    return eqx.tree_at(where, self, **options)

with_actors ¤

with_actors(actors: NodeContainer[Actor]) -> Self
Source code in src/liblaf/apple/energy/collision/vert_face/vert_face.py
69
70
71
@override
def with_actors(self, actors: struct.NodeContainer[sim.Actor]) -> Self:
    return self.replace(rigid=actors[self.rigid.id], soft=actors[self.soft.id])

EnergyZero ¤

Bases: Energy

EnergyZero(_actors: liblaf.apple.struct.dictutils._node_container.NodeContainer[liblaf.apple.sim.actor.actor.Actor] = <factory>, *, id: str = None)

Parameters:

  • id (str, default: None ) –

Methods:

Attributes:

__dataclass_fields__ class-attribute ¤

__dataclass_fields__: dict[str, Field[Any]]

actors property ¤

id class-attribute instance-attribute ¤

id: str = field(default=None, kw_only=True)

__post_init__ ¤

__post_init__() -> None
Source code in src/liblaf/apple/struct/tree/_node.py
13
14
15
def __post_init__(self) -> None:
    """Assign an auto-generated ``id`` (via ``uniq_id``) when none was given."""
    if self.id is not None:
        return
    # Bypass any frozen-dataclass __setattr__ guard.
    object.__setattr__(self, "id", uniq_id(self))

from_actors classmethod ¤

from_actors(*actors: Actor) -> Self
Source code in src/liblaf/apple/energy/zero.py
17
18
19
@classmethod
def from_actors(cls, *actors: sim.Actor) -> Self:
    """Create a zero energy attached to the given actors."""
    container = struct.NodeContainer(actors)
    return cls(_actors=container)

fun ¤

fun(x: ArrayDict, /, params: GlobalParams) -> FloatScalar
Source code in src/liblaf/apple/energy/zero.py
30
31
32
33
@override
@utils.jit(inline=True)
def fun(self, x: struct.ArrayDict, /, params: sim.GlobalParams) -> FloatScalar:
    """Energy value: always a scalar zero, independent of ``x`` and ``params``."""
    return jnp.zeros(shape=())

fun_and_jac ¤

fun_and_jac(
    x: ArrayDict, /, params: GlobalParams
) -> tuple[Float[Array, ""], ArrayDict]
Source code in src/liblaf/apple/sim/energy/energy.py
87
88
89
90
91
92
@utils.not_implemented
@utils.jit
def fun_and_jac(
    self, x: struct.ArrayDict, /, params: GlobalParams
) -> tuple[Float[Array, ""], struct.ArrayDict]:
    """Return ``(fun(x), jac(x))``, evaluating ``fun`` first and then ``jac``."""
    value = self.fun(x, params)
    gradient = self.jac(x, params)
    return value, gradient

hess_diag ¤

hess_diag(
    x: ArrayDict, /, params: GlobalParams
) -> ArrayDict
Source code in src/liblaf/apple/energy/zero.py
42
43
44
45
46
47
48
49
@override
@utils.jit(inline=True)
def hess_diag(
    self, x: struct.ArrayDict, /, params: sim.GlobalParams
) -> struct.ArrayDict:
    """Hessian diagonal: zeros shaped like ``x[actor.id]`` for every attached actor."""
    zeros = {}
    for actor in self.actors.values():
        zeros[actor.id] = jnp.zeros_like(x[actor.id])
    return struct.ArrayDict(zeros)

hess_quad ¤

hess_quad(
    x: ArrayDict, p: ArrayDict, /, params: GlobalParams
) -> FloatScalar
Source code in src/liblaf/apple/energy/zero.py
51
52
53
54
55
56
@override
@utils.jit(inline=True)
def hess_quad(
    self, x: struct.ArrayDict, p: struct.ArrayDict, /, params: sim.GlobalParams
) -> FloatScalar:
    """Hessian quadratic-form term: always a scalar zero."""
    return jnp.zeros(shape=())

hessp ¤

hessp(
    x: ArrayDict, p: ArrayDict, /, params: GlobalParams
) -> ArrayDict
Source code in src/liblaf/apple/sim/energy/energy.py
63
64
65
66
67
68
@utils.not_implemented
@utils.jit
def hessp(
    self, x: struct.ArrayDict, p: struct.ArrayDict, /, params: GlobalParams
) -> struct.ArrayDict:
    """Hessian-vector product obtained as ``math.jvp`` of ``self.jac``.

    Args:
        x: Point at which the Hessian is taken.
        p: Direction vector.
        params: Forwarded to ``jac``.
    """
    jac_jvp = math.jvp(self.jac)
    return jac_jvp(x, p, params)

jac ¤

jac(x: ArrayDict, /, params: GlobalParams) -> ArrayDict
Source code in src/liblaf/apple/energy/zero.py
35
36
37
38
39
40
@override
@utils.jit(inline=True)
def jac(self, x: struct.ArrayDict, /, params: sim.GlobalParams) -> struct.ArrayDict:
    """Gradient: zeros shaped like ``x[actor.id]`` for every attached actor."""
    gradient = {}
    for actor in self.actors.values():
        gradient[actor.id] = jnp.zeros_like(x[actor.id])
    return struct.ArrayDict(gradient)

jac_and_hess_diag ¤

jac_and_hess_diag(
    x: ArrayDict, /, params: GlobalParams
) -> tuple[ArrayDict, ArrayDict]
Source code in src/liblaf/apple/sim/energy/energy.py
94
95
96
97
98
99
@utils.not_implemented
@utils.jit
def jac_and_hess_diag(
    self, x: struct.ArrayDict, /, params: GlobalParams
) -> tuple[struct.ArrayDict, struct.ArrayDict]:
    """Return ``(jac(x), hess_diag(x))``, evaluating the gradient first."""
    gradient = self.jac(x, params)
    diagonal = self.hess_diag(x, params)
    return gradient, diagonal

pre_optim_iter ¤

pre_optim_iter(params: GlobalParams) -> Self
Source code in src/liblaf/apple/sim/energy/energy.py
24
25
def pre_optim_iter(self, params: GlobalParams) -> Self:
    """Per-optimizer-iteration hook; the base implementation is a no-op.

    Args:
        params: Ignored.

    Returns:
        ``self``, unchanged.
    """
    del params  # unused by the base hook
    return self

pre_optim_iter_jit deprecated ¤

pre_optim_iter_jit(params: GlobalParams) -> Self
Deprecated

deprecated.

Source code in src/liblaf/apple/sim/energy/energy.py
27
28
29
30
@utils.jit(inline=True, validate=False)
@deprecated("deprecated.")
def pre_optim_iter_jit(self, params: GlobalParams) -> Self:
    """Deprecated no-op hook.

    Args:
        params: Ignored.

    Returns:
        ``self``, unchanged.
    """
    del params  # intentionally unused
    return self

pre_optim_iter_no_jit deprecated ¤

pre_optim_iter_no_jit(params: GlobalParams) -> Self
Deprecated

deprecated.

Source code in src/liblaf/apple/sim/energy/energy.py
32
33
34
@deprecated("deprecated.")
def pre_optim_iter_no_jit(self, params: GlobalParams) -> Self:
    """Deprecated no-op hook (non-jitted variant).

    Args:
        params: Ignored.

    Returns:
        ``self``, unchanged.
    """
    del params  # intentionally unused
    return self

pre_time_step ¤

pre_time_step(params: GlobalParams) -> Self
Source code in src/liblaf/apple/sim/energy/energy.py
21
22
def pre_time_step(self, params: GlobalParams) -> Self:
    """Per-time-step hook; the base implementation is a no-op.

    Args:
        params: Ignored.

    Returns:
        ``self``, unchanged.
    """
    del params  # unused by the base hook
    return self

replace ¤

replace(**changes: Any) -> Self
Source code in src/liblaf/apple/struct/tree/_pytree.py
19
20
def replace(self, **changes: Any) -> Self:
    """Return a copy of this dataclass with ``changes`` applied.

    Delegates to :func:`dataclasses.replace`; ``self`` is left untouched.
    """
    updated = dataclasses.replace(self, **changes)
    return updated

tree_at ¤

tree_at(
    where: Callable[[Self], Node | Sequence[Node]],
    replace: Any | Sequence[Any] = MISSING,
    replace_fn: Callable[[Node], Any] = MISSING,
    is_leaf: Callable[[Any], bool] | None = None,
) -> Self
Source code in src/liblaf/apple/struct/tree/_pytree.py
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
def tree_at(
    self,
    where: Callable[[Self], Node | Sequence[Node]],
    replace: Any | Sequence[Any] | MISSING = MISSING,
    replace_fn: Callable[[Node], Any] | MISSING = MISSING,
    is_leaf: Callable[[Any], bool] | None = None,
) -> Self:
    """Functional update of this pytree via ``eqx.tree_at``.

    Only options the caller actually supplied are forwarded. ``replace`` and
    ``replace_fn`` are omitted when left as ``MISSING``; ``is_leaf`` is
    omitted when ``None``.
    """
    options: dict[str, Any] = {
        key: value
        for key, value in (("replace", replace), ("replace_fn", replace_fn))
        if value is not MISSING
    }
    if is_leaf is not None:
        options["is_leaf"] = is_leaf
    return eqx.tree_at(where, self, **options)

with_actors ¤

with_actors(actors: NodeContainer[Actor]) -> Self
Source code in src/liblaf/apple/energy/zero.py
26
27
28
@override
def with_actors(self, actors: struct.NodeContainer[sim.Actor]) -> Self:
    return self.replace(_actors=actors)

PhaceActive ¤

Bases: Elastic

PhaceActive(actor: liblaf.apple.sim.actor.actor.Actor, *, id: str = None, hess_diag_filter: bool = True, hess_quad_filter: bool = True)

Parameters:

  • id (str, default: None ) –
  • actor (Actor) –
  • hess_diag_filter (bool, default: True ) –
  • hess_quad_filter (bool, default: True ) –

Methods:

Attributes:

__dataclass_fields__ class-attribute ¤

__dataclass_fields__: dict[str, Field[Any]]

activation property ¤

activation: Float[Array, ' cells J J']

actor class-attribute instance-attribute ¤

actor: Actor = field()

actors property ¤

hess_diag_filter class-attribute instance-attribute ¤

hess_diag_filter: bool = field(default=True, kw_only=True)

hess_quad_filter class-attribute instance-attribute ¤

hess_quad_filter: bool = field(default=True, kw_only=True)

id class-attribute instance-attribute ¤

id: str = field(default=None, kw_only=True)

lambda_ property ¤

lambda_: Float[Array, ' cells']

mu property ¤

mu: Float[Array, ' cells']

muscle_fraction property ¤

muscle_fraction: Float[Array, ' cells']

region property ¤

region: Region

__post_init__ ¤

__post_init__() -> None
Source code in src/liblaf/apple/struct/tree/_node.py
13
14
15
def __post_init__(self) -> None:
    """Assign an auto-generated ``id`` (via ``uniq_id``) when none was given."""
    if self.id is not None:
        return
    # Bypass any frozen-dataclass __setattr__ guard.
    object.__setattr__(self, "id", uniq_id(self))

energy_density ¤

energy_density(
    field: Field, /, params: GlobalParams
) -> Float[Array, "c q"]
Source code in src/liblaf/apple/energy/elastic/phace_active/_phace_active.py
29
30
31
32
33
34
35
36
37
38
39
40
41
42
@override
@utils.jit(inline=True)
def energy_density(
    self, field: Field, /, params: sim.GlobalParams
) -> Float[jax.Array, "c q"]:
    """PhaceActive energy density at each quadrature point.

    Flattens the deformation gradient's cell and quadrature axes, evaluates
    ``kernel.phace_active_energy_density_kernel`` and restores the
    ``(c, q)`` layout.
    """
    region: sim.Region = self.region
    F_flat: Float[jax.Array, "cq J J"] = region.squeeze_cq(
        region.deformation_gradient(field)
    )
    (psi_flat,) = kernel.phace_active_energy_density_kernel(
        F_flat, self.activation, self.lambda_, self.mu, self.muscle_fraction
    )
    return region.unsqueeze_cq(psi_flat)

energy_density_hess_diag ¤

energy_density_hess_diag(
    field: Field, /, params: GlobalParams
) -> Float[Array, "c q a J"]
Source code in src/liblaf/apple/energy/elastic/phace_active/_phace_active.py
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
@override
@utils.jit(inline=True)
def energy_density_hess_diag(
    self, field: Field, /, params: sim.GlobalParams
) -> Float[jax.Array, "c q a J"]:
    """Per-quadrature-point Hessian diagonal of the PhaceActive energy density.

    Works on flattened ``cq`` axes for the kernel call, then reshapes back to
    ``(c, q, a, J)``.
    """
    region: sim.Region = self.region
    F_flat = region.squeeze_cq(region.deformation_gradient(field))
    dhdX_flat = region.squeeze_cq(region.dhdX)
    (diag_flat,) = kernel.phace_active_energy_density_hess_diag_kernel(
        F_flat, dhdX_flat, self.activation, self.lambda_, self.mu, self.muscle_fraction
    )
    return region.unsqueeze_cq(diag_flat)

energy_density_hess_quad ¤

energy_density_hess_quad(
    field: Field, p: Field, /, params: GlobalParams
) -> Float[Array, "c q"]
Source code in src/liblaf/apple/energy/elastic/phace_active/_phace_active.py
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
@override
@utils.jit(inline=True)
def energy_density_hess_quad(
    self, field: Field, p: Field, /, params: sim.GlobalParams
) -> Float[jax.Array, "c q"]:
    """Per-quadrature-point Hessian quadratic-form term of the PhaceActive density.

    ``p`` is scattered to element vertices with ``region.scatter`` before the
    kernel call. The result is reshaped back to ``(c, q)``.
    """
    region: sim.Region = self.region
    F_flat = region.squeeze_cq(region.deformation_gradient(field))
    dhdX_flat = region.squeeze_cq(region.dhdX)
    p_elem = region.scatter(p)
    (quad_flat,) = kernel.phace_active_energy_density_hess_quad_kernel(
        F_flat,
        p_elem,
        dhdX_flat,
        self.activation,
        self.lambda_,
        self.mu,
        self.muscle_fraction,
    )
    return region.unsqueeze_cq(quad_flat)

energy_density_jac ¤

energy_density_jac(
    field: Field, /, params: GlobalParams
) -> Float[Array, "c q a J"]
Source code in src/liblaf/apple/energy/elastic/_elastic.py
116
117
118
119
120
121
122
@utils.jit(inline=True)
def energy_density_jac(
    self, field: Field, /, params: sim.GlobalParams
) -> Float[Array, "c q a J"]:
    """Per-quadrature-point gradient of the energy density.

    Maps the first Piola-Kirchhoff stress through ``region.gradient_vjp``.
    """
    stress = self.first_piola_kirchhoff_stress(field, params)
    return self.region.gradient_vjp(stress)

first_piola_kirchhoff_stress ¤

first_piola_kirchhoff_stress(
    field: Field, /, params: GlobalParams
) -> Float[Array, "c q J J"]
Source code in src/liblaf/apple/energy/elastic/phace_active/_phace_active.py
44
45
46
47
48
49
50
51
52
53
54
55
56
57
@override
@utils.jit(inline=True)
def first_piola_kirchhoff_stress(
    self, field: Field, /, params: sim.GlobalParams
) -> Float[jax.Array, "c q J J"]:
    """First Piola-Kirchhoff stress of the PhaceActive model per quadrature point."""
    region: sim.Region = self.region
    F_flat = region.squeeze_cq(region.deformation_gradient(field))
    (stress_flat,) = kernel.phace_active_first_piola_kirchhoff_stress_kernel(
        F_flat, self.activation, self.lambda_, self.mu, self.muscle_fraction
    )
    return region.unsqueeze_cq(stress_flat)

from_actor classmethod ¤

from_actor(
    actor: Actor,
    *,
    hess_diag_filter: bool = True,
    hess_quad_filter: bool = True,
) -> Self
Source code in src/liblaf/apple/energy/elastic/_elastic.py
16
17
18
19
20
21
22
23
24
25
26
27
28
@classmethod
def from_actor(
    cls,
    actor: sim.Actor,
    *,
    hess_diag_filter: bool = True,
    hess_quad_filter: bool = True,
) -> Self:
    """Construct the elastic energy for a single actor.

    Args:
        actor: Actor the energy acts on.
        hess_diag_filter: Clamp negative Hessian-diagonal entries.
        hess_quad_filter: Clamp negative Hessian quadratic-form terms.
    """
    filters = {
        "hess_diag_filter": hess_diag_filter,
        "hess_quad_filter": hess_quad_filter,
    }
    return cls(actor=actor, **filters)

fun ¤

fun(
    x: ArrayDict, /, params: GlobalParams
) -> Float[Array, ""]
Source code in src/liblaf/apple/energy/elastic/_elastic.py
43
44
45
46
47
48
49
@override
@utils.jit(inline=True)
def fun(self, x: struct.ArrayDict, /, params: sim.GlobalParams) -> Float[Array, ""]:
    """Total elastic energy: the energy density integrated over the region, summed over cells."""
    psi_qp: Float[Array, "c q"] = self.energy_density(x[self.actor.id], params)
    psi_cell: Float[Array, " c"] = self.region.integrate(psi_qp)
    return jnp.sum(psi_cell)

fun_and_jac ¤

fun_and_jac(
    x: ArrayDict, /, params: GlobalParams
) -> tuple[Float[Array, ""], ArrayDict]
Source code in src/liblaf/apple/energy/elastic/_elastic.py
92
93
94
95
96
97
@override
@utils.jit(inline=True)
def fun_and_jac(
    self, x: struct.ArrayDict, /, params: sim.GlobalParams
) -> tuple[Float[Array, ""], struct.ArrayDict]:
    """Return ``(fun(x), jac(x))``, evaluating ``fun`` first and then ``jac``."""
    value = self.fun(x, params)
    gradient = self.jac(x, params)
    return value, gradient

hess_diag ¤

hess_diag(
    x: ArrayDict, /, params: GlobalParams
) -> ArrayDict
Source code in src/liblaf/apple/energy/elastic/_elastic.py
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
@override
@utils.jit(inline=True)
def hess_diag(
    self, x: struct.ArrayDict, /, params: sim.GlobalParams
) -> struct.ArrayDict:
    """Per-vertex diagonal of the elastic-energy Hessian.

    Negative quadrature-level entries are clamped to zero when
    ``self.hess_diag_filter`` is set. The values are then integrated per cell
    and gathered to vertices.
    """
    per_qp: Float[Array, "c q a J"] = self.energy_density_hess_diag(
        x[self.actor.id], params
    )
    if self.hess_diag_filter:
        per_qp = jnp.clip(per_qp, min=0.0)
    per_cell: Float[Array, "c a J"] = self.region.integrate(per_qp)
    per_vertex: Float[Array, "p J"] = self.region.gather(per_cell)
    return struct.ArrayDict({self.actor.id: per_vertex})

hess_quad ¤

hess_quad(
    x: ArrayDict, p: ArrayDict, /, params: GlobalParams
) -> Float[Array, ""]
Source code in src/liblaf/apple/energy/elastic/_elastic.py
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
@override
@utils.jit(inline=True)
def hess_quad(
    self, x: struct.ArrayDict, p: struct.ArrayDict, /, params: sim.GlobalParams
) -> Float[Array, ""]:
    """Hessian quadratic-form term of the elastic energy along ``p``.

    Quadrature-level values are clamped at zero when ``self.hess_quad_filter``
    is set, then integrated and summed to a scalar.
    """
    actor_id = self.actor.id
    per_qp: Float[Array, "c q"] = self.energy_density_hess_quad(
        x[actor_id], p[actor_id], params
    )
    if self.hess_quad_filter:
        per_qp = jnp.clip(per_qp, min=0.0)
    return jnp.sum(self.region.integrate(per_qp))

hessp ¤

hessp(
    x: ArrayDict, p: ArrayDict, /, params: GlobalParams
) -> ArrayDict
Source code in src/liblaf/apple/sim/energy/energy.py
63
64
65
66
67
68
@utils.not_implemented
@utils.jit
def hessp(
    self, x: struct.ArrayDict, p: struct.ArrayDict, /, params: GlobalParams
) -> struct.ArrayDict:
    """Hessian-vector product obtained as ``math.jvp`` of ``self.jac``.

    Args:
        x: Point at which the Hessian is taken.
        p: Direction vector.
        params: Forwarded to ``jac``.
    """
    jac_jvp = math.jvp(self.jac)
    return jac_jvp(x, p, params)

jac ¤

jac(x: ArrayDict, /, params: GlobalParams) -> ArrayDict
Source code in src/liblaf/apple/energy/elastic/_elastic.py
51
52
53
54
55
56
57
58
@override
@utils.jit(inline=True)
def jac(self, x: struct.ArrayDict, /, params: sim.GlobalParams) -> struct.ArrayDict:
    """Per-vertex gradient of the elastic energy (integrated, then gathered)."""
    grad_qp: Float[Array, "c q a J"] = self.energy_density_jac(x[self.actor.id], params)
    grad_cell: Float[Array, "c a J"] = self.region.integrate(grad_qp)
    grad_vertex: Float[Array, "p J"] = self.region.gather(grad_cell)
    return struct.ArrayDict({self.actor.id: grad_vertex})

jac_and_hess_diag ¤

jac_and_hess_diag(
    x: ArrayDict, /, params: GlobalParams
) -> tuple[ArrayDict, ArrayDict]
Source code in src/liblaf/apple/energy/elastic/_elastic.py
 99
100
101
102
103
104
@override
@utils.jit(inline=True)
def jac_and_hess_diag(
    self, x: struct.ArrayDict, /, params: sim.GlobalParams
) -> tuple[struct.ArrayDict, struct.ArrayDict]:
    """Return ``(jac(x), hess_diag(x))``, evaluating the gradient first."""
    gradient = self.jac(x, params)
    diagonal = self.hess_diag(x, params)
    return gradient, diagonal

pre_optim_iter ¤

pre_optim_iter(params: GlobalParams) -> Self
Source code in src/liblaf/apple/sim/energy/energy.py
24
25
def pre_optim_iter(self, params: GlobalParams) -> Self:
    """Per-optimizer-iteration hook; the base implementation is a no-op.

    Args:
        params: Ignored.

    Returns:
        ``self``, unchanged.
    """
    del params  # unused by the base hook
    return self

pre_optim_iter_jit deprecated ¤

pre_optim_iter_jit(params: GlobalParams) -> Self
Deprecated

deprecated.

Source code in src/liblaf/apple/sim/energy/energy.py
27
28
29
30
@utils.jit(inline=True, validate=False)
@deprecated("deprecated.")
def pre_optim_iter_jit(self, params: GlobalParams) -> Self:
    """Deprecated no-op hook.

    Args:
        params: Ignored.

    Returns:
        ``self``, unchanged.
    """
    del params  # intentionally unused
    return self

pre_optim_iter_no_jit deprecated ¤

pre_optim_iter_no_jit(params: GlobalParams) -> Self
Deprecated

deprecated.

Source code in src/liblaf/apple/sim/energy/energy.py
32
33
34
@deprecated("deprecated.")
def pre_optim_iter_no_jit(self, params: GlobalParams) -> Self:
    """Deprecated no-op hook (non-jitted variant).

    Args:
        params: Ignored.

    Returns:
        ``self``, unchanged.
    """
    del params  # intentionally unused
    return self

pre_time_step ¤

pre_time_step(params: GlobalParams) -> Self
Source code in src/liblaf/apple/sim/energy/energy.py
21
22
def pre_time_step(self, params: GlobalParams) -> Self:
    """Per-time-step hook; the base implementation is a no-op.

    Args:
        params: Ignored.

    Returns:
        ``self``, unchanged.
    """
    del params  # unused by the base hook
    return self

replace ¤

replace(**changes: Any) -> Self
Source code in src/liblaf/apple/struct/tree/_pytree.py
19
20
def replace(self, **changes: Any) -> Self:
    """Return a copy of this dataclass with ``changes`` applied.

    Delegates to :func:`dataclasses.replace`; ``self`` is left untouched.
    """
    updated = dataclasses.replace(self, **changes)
    return updated

tree_at ¤

tree_at(
    where: Callable[[Self], Node | Sequence[Node]],
    replace: Any | Sequence[Any] = MISSING,
    replace_fn: Callable[[Node], Any] = MISSING,
    is_leaf: Callable[[Any], bool] | None = None,
) -> Self
Source code in src/liblaf/apple/struct/tree/_pytree.py
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
def tree_at(
    self,
    where: Callable[[Self], Node | Sequence[Node]],
    replace: Any | Sequence[Any] | MISSING = MISSING,
    replace_fn: Callable[[Node], Any] | MISSING = MISSING,
    is_leaf: Callable[[Any], bool] | None = None,
) -> Self:
    """Functional update of this pytree via ``eqx.tree_at``.

    Only options the caller actually supplied are forwarded. ``replace`` and
    ``replace_fn`` are omitted when left as ``MISSING``; ``is_leaf`` is
    omitted when ``None``.
    """
    options: dict[str, Any] = {
        key: value
        for key, value in (("replace", replace), ("replace_fn", replace_fn))
        if value is not MISSING
    }
    if is_leaf is not None:
        options["is_leaf"] = is_leaf
    return eqx.tree_at(where, self, **options)

with_actors ¤

with_actors(actors: NodeContainer[Actor]) -> Self
Source code in src/liblaf/apple/energy/elastic/_elastic.py
35
36
37
@override
def with_actors(self, actors: struct.NodeContainer[sim.Actor]) -> Self:
    return self.replace(actor=actors[self.actor.id])

PhacePassive ¤

Bases: Elastic

PhacePassive(actor: liblaf.apple.sim.actor.actor.Actor, *, id: str = None, hess_diag_filter: bool = True, hess_quad_filter: bool = True)

Parameters:

  • id (str, default: None ) –
  • actor (Actor) –
  • hess_diag_filter (bool, default: True ) –
  • hess_quad_filter (bool, default: True ) –

Methods:

Attributes:

__dataclass_fields__ class-attribute ¤

__dataclass_fields__: dict[str, Field[Any]]

activation property ¤

activation: Float[Array, 'c J J']

actor class-attribute instance-attribute ¤

actor: Actor = field()

actors property ¤

hess_diag_filter class-attribute instance-attribute ¤

hess_diag_filter: bool = field(default=True, kw_only=True)

hess_quad_filter class-attribute instance-attribute ¤

hess_quad_filter: bool = field(default=True, kw_only=True)

id class-attribute instance-attribute ¤

id: str = field(default=None, kw_only=True)

lambda_ property ¤

lambda_: Float[Array, ' cells']

mu property ¤

mu: Float[Array, ' cells']

muscle_fraction property ¤

muscle_fraction: Float[Array, ' cells']

region property ¤

region: Region

__post_init__ ¤

__post_init__() -> None
Source code in src/liblaf/apple/struct/tree/_node.py
13
14
15
def __post_init__(self) -> None:
    """Assign an auto-generated ``id`` (via ``uniq_id``) when none was given."""
    if self.id is not None:
        return
    # Bypass any frozen-dataclass __setattr__ guard.
    object.__setattr__(self, "id", uniq_id(self))

energy_density ¤

energy_density(
    field: Field, /, params: GlobalParams
) -> Float[Array, "c q"]
Source code in src/liblaf/apple/energy/elastic/phace_passive/_phace_passive.py
29
30
31
32
33
34
35
36
37
38
39
40
@override
@utils.jit(inline=True)
def energy_density(
    self, field: Field, /, params: sim.GlobalParams
) -> Float[jax.Array, "c q"]:
    """PhacePassive energy density at each quadrature point.

    Flattens the deformation gradient's cell and quadrature axes, evaluates
    ``kernel.phace_passive_energy_density_kernel`` and restores ``(c, q)``.
    """
    region: sim.Region = self.region
    F_flat: Float[jax.Array, "cq J J"] = region.squeeze_cq(
        region.deformation_gradient(field)
    )
    (psi_flat,) = kernel.phace_passive_energy_density_kernel(
        F_flat, self.lambda_, self.mu
    )
    return region.unsqueeze_cq(psi_flat)

energy_density_hess_diag ¤

energy_density_hess_diag(
    field: Field, /, params: GlobalParams
) -> Float[Array, "c q a J"]
Source code in src/liblaf/apple/energy/elastic/phace_passive/_phace_passive.py
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
@override
@utils.jit(inline=True)
def energy_density_hess_diag(
    self, field: Field, /, params: sim.GlobalParams
) -> Float[jax.Array, "c q a J"]:
    """Per-quadrature-point Hessian diagonal of the PhacePassive energy density.

    Works on flattened ``cq`` axes for the kernel call, then reshapes back to
    ``(c, q, a, J)``.
    """
    region: sim.Region = self.region
    F_flat = region.squeeze_cq(region.deformation_gradient(field))
    dhdX_flat = region.squeeze_cq(region.dhdX)
    (diag_flat,) = kernel.phace_passive_energy_density_hess_diag_kernel(
        F_flat, dhdX_flat, self.lambda_, self.mu
    )
    return region.unsqueeze_cq(diag_flat)

energy_density_hess_quad ¤

energy_density_hess_quad(
    field: Field, p: Field, /, params: GlobalParams
) -> Float[Array, "c q"]
Source code in src/liblaf/apple/energy/elastic/phace_passive/_phace_passive.py
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
@override
@utils.jit(inline=True)
def energy_density_hess_quad(
    self, field: Field, p: Field, /, params: sim.GlobalParams
) -> Float[jax.Array, "c q"]:
    """Per-quadrature-point Hessian quadratic-form term of the PhacePassive density.

    ``p`` is scattered to element vertices with ``region.scatter`` before the
    kernel call. The result is reshaped back to ``(c, q)``.
    """
    region: sim.Region = self.region
    F_flat = region.squeeze_cq(region.deformation_gradient(field))
    dhdX_flat = region.squeeze_cq(region.dhdX)
    p_elem = region.scatter(p)
    (quad_flat,) = kernel.phace_passive_energy_density_hess_quad_kernel(
        F_flat, p_elem, dhdX_flat, self.lambda_, self.mu
    )
    return region.unsqueeze_cq(quad_flat)

energy_density_jac ¤

energy_density_jac(
    field: Field, /, params: GlobalParams
) -> Float[Array, "c q a J"]
Source code in src/liblaf/apple/energy/elastic/_elastic.py
116
117
118
119
120
121
122
@utils.jit(inline=True)
def energy_density_jac(
    self, field: Field, /, params: sim.GlobalParams
) -> Float[Array, "c q a J"]:
    """Per-quadrature-point gradient of the energy density.

    Maps the first Piola-Kirchhoff stress through ``region.gradient_vjp``.
    """
    stress = self.first_piola_kirchhoff_stress(field, params)
    return self.region.gradient_vjp(stress)

first_piola_kirchhoff_stress ¤

first_piola_kirchhoff_stress(
    field: Field, /, params: GlobalParams
) -> Float[Array, "c q J J"]
Source code in src/liblaf/apple/energy/elastic/phace_passive/_phace_passive.py
42
43
44
45
46
47
48
49
50
51
52
53
54
55
@override
@utils.jit(inline=True)
def first_piola_kirchhoff_stress(
    self, field: Field, /, params: sim.GlobalParams
) -> Float[jax.Array, "c q J J"]:
    """First Piola-Kirchhoff stress of the PhacePassive model per quadrature point."""
    region: sim.Region = self.region
    F_flat = region.squeeze_cq(region.deformation_gradient(field))
    (stress_flat,) = kernel.phace_passive_first_piola_kirchhoff_stress_kernel(
        F_flat, self.lambda_, self.mu
    )
    return region.unsqueeze_cq(stress_flat)

from_actor classmethod ¤

from_actor(
    actor: Actor,
    *,
    hess_diag_filter: bool = True,
    hess_quad_filter: bool = True,
) -> Self
Source code in src/liblaf/apple/energy/elastic/_elastic.py
16
17
18
19
20
21
22
23
24
25
26
27
28
@classmethod
def from_actor(
    cls,
    actor: sim.Actor,
    *,
    hess_diag_filter: bool = True,
    hess_quad_filter: bool = True,
) -> Self:
    """Construct the elastic energy for a single actor.

    Args:
        actor: Actor the energy acts on.
        hess_diag_filter: Clamp negative Hessian-diagonal entries.
        hess_quad_filter: Clamp negative Hessian quadratic-form terms.
    """
    filters = {
        "hess_diag_filter": hess_diag_filter,
        "hess_quad_filter": hess_quad_filter,
    }
    return cls(actor=actor, **filters)

fun ¤

fun(
    x: ArrayDict, /, params: GlobalParams
) -> Float[Array, ""]
Source code in src/liblaf/apple/energy/elastic/_elastic.py
43
44
45
46
47
48
49
@override
@utils.jit(inline=True)
def fun(self, x: struct.ArrayDict, /, params: sim.GlobalParams) -> Float[Array, ""]:
    """Total elastic energy: the energy density integrated over the region, summed over cells."""
    psi_qp: Float[Array, "c q"] = self.energy_density(x[self.actor.id], params)
    psi_cell: Float[Array, " c"] = self.region.integrate(psi_qp)
    return jnp.sum(psi_cell)

fun_and_jac ¤

fun_and_jac(
    x: ArrayDict, /, params: GlobalParams
) -> tuple[Float[Array, ""], ArrayDict]
Source code in src/liblaf/apple/energy/elastic/_elastic.py
92
93
94
95
96
97
@override
@utils.jit(inline=True)
def fun_and_jac(
    self, x: struct.ArrayDict, /, params: sim.GlobalParams
) -> tuple[Float[Array, ""], struct.ArrayDict]:
    """Return ``(fun(x), jac(x))``, evaluating ``fun`` first and then ``jac``."""
    value = self.fun(x, params)
    gradient = self.jac(x, params)
    return value, gradient

hess_diag ¤

hess_diag(
    x: ArrayDict, /, params: GlobalParams
) -> ArrayDict
Source code in src/liblaf/apple/energy/elastic/_elastic.py
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
@override
@utils.jit(inline=True)
def hess_diag(
    self, x: struct.ArrayDict, /, params: sim.GlobalParams
) -> struct.ArrayDict:
    """Per-vertex diagonal of the elastic-energy Hessian.

    Negative quadrature-level entries are clamped to zero when
    ``self.hess_diag_filter`` is set. The values are then integrated per cell
    and gathered to vertices.
    """
    per_qp: Float[Array, "c q a J"] = self.energy_density_hess_diag(
        x[self.actor.id], params
    )
    if self.hess_diag_filter:
        per_qp = jnp.clip(per_qp, min=0.0)
    per_cell: Float[Array, "c a J"] = self.region.integrate(per_qp)
    per_vertex: Float[Array, "p J"] = self.region.gather(per_cell)
    return struct.ArrayDict({self.actor.id: per_vertex})

hess_quad ¤

hess_quad(
    x: ArrayDict, p: ArrayDict, /, params: GlobalParams
) -> Float[Array, ""]
Source code in src/liblaf/apple/energy/elastic/_elastic.py
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
@override
@utils.jit(inline=True)
def hess_quad(
    self, x: struct.ArrayDict, p: struct.ArrayDict, /, params: sim.GlobalParams
) -> Float[Array, ""]:
    """Integrated Hessian quadratic-form term for the direction ``p``.

    Per-quadrature-point values come from ``energy_density_hess_quad``
    (presumably ``p^T H(x) p`` densities; confirm). When
    ``self.hess_quad_filter`` is set, they are clamped to be >= 0. The values
    are then integrated per cell and summed to a scalar.

    Args:
        x: State keyed by actor id; only ``x[self.actor.id]`` is read.
        p: Direction keyed by actor id; only ``p[self.actor.id]`` is read.
        params: Global parameters forwarded to the density term.
    """
    quad_density: Float[Array, "c q"] = self.energy_density_hess_quad(
        x[self.actor.id], p[self.actor.id], params
    )
    if self.hess_quad_filter:
        quad_density = jnp.clip(quad_density, min=0.0)
    per_cell: Float[Array, " c"] = self.region.integrate(quad_density)
    return jnp.sum(per_cell)

hessp ¤

hessp(
    x: ArrayDict, p: ArrayDict, /, params: GlobalParams
) -> ArrayDict
Source code in src/liblaf/apple/sim/energy/energy.py
63
64
65
66
67
68
@utils.not_implemented
@utils.jit
def hessp(
    self, x: struct.ArrayDict, p: struct.ArrayDict, /, params: GlobalParams
) -> struct.ArrayDict:
    """Hessian-vector product built by differentiating ``self.jac``.

    ``math.jvp`` is a project helper, presumably a forward-mode
    Jacobian-vector product of ``self.jac`` at ``x`` along ``p``; confirm.
    The method is wrapped in ``utils.not_implemented``, whose exact effect
    is defined elsewhere.
    """
    jac_directional = math.jvp(self.jac)
    return jac_directional(x, p, params)

jac ¤

jac(x: ArrayDict, /, params: GlobalParams) -> ArrayDict
Source code in src/liblaf/apple/energy/elastic/_elastic.py
51
52
53
54
55
56
57
58
@override
@utils.jit(inline=True)
def jac(self, x: struct.ArrayDict, /, params: sim.GlobalParams) -> struct.ArrayDict:
    """Jacobian of the energy with respect to this actor's field.

    The density Jacobian is evaluated per quadrature point, integrated over
    each cell, then gathered onto the points.

    Returns:
        An ``ArrayDict`` with the single entry ``{self.actor.id: grad}``, where
        ``grad`` has shape ``(p, J)``.
    """
    per_quad: Float[Array, "c q a J"] = self.energy_density_jac(
        x[self.actor.id], params
    )
    per_cell: Float[Array, "c a J"] = self.region.integrate(per_quad)
    per_point: Float[Array, "p J"] = self.region.gather(per_cell)
    return struct.ArrayDict({self.actor.id: per_point})

jac_and_hess_diag ¤

jac_and_hess_diag(
    x: ArrayDict, /, params: GlobalParams
) -> tuple[ArrayDict, ArrayDict]
Source code in src/liblaf/apple/energy/elastic/_elastic.py
 99
100
101
102
103
104
@override
@utils.jit(inline=True)
def jac_and_hess_diag(
    self, x: struct.ArrayDict, /, params: sim.GlobalParams
) -> tuple[struct.ArrayDict, struct.ArrayDict]:
    """Return the Jacobian and the Hessian diagonal at ``x`` as a pair.

    Both values are produced by calling ``self.jac`` and then
    ``self.hess_diag``.
    """
    gradient = self.jac(x, params)
    diagonal = self.hess_diag(x, params)
    return gradient, diagonal

pre_optim_iter ¤

pre_optim_iter(params: GlobalParams) -> Self
Source code in src/liblaf/apple/sim/energy/energy.py
24
25
def pre_optim_iter(self, params: GlobalParams) -> Self:  # noqa: ARG002
    """Hook presumably invoked before each optimizer iteration.

    The default implementation ignores ``params`` and does nothing.

    Returns:
        ``self``, unchanged.
    """
    unchanged = self
    return unchanged

pre_optim_iter_jit deprecated ¤

pre_optim_iter_jit(params: GlobalParams) -> Self
Deprecated

deprecated.

Source code in src/liblaf/apple/sim/energy/energy.py
27
28
29
30
@utils.jit(inline=True, validate=False)
@deprecated("deprecated.")
def pre_optim_iter_jit(self, params: GlobalParams) -> Self:  # noqa: ARG002
    """Deprecated jitted variant of the pre-optimizer-iteration hook.

    Ignores ``params`` and returns ``self`` unchanged.

    NOTE(review): ``deprecated`` is applied inside ``utils.jit``. If
    ``utils.jit`` traces and caches like ``jax.jit``, the deprecation warning
    may only fire when tracing, not on every call. Confirm whether that is
    intended.
    """
    unchanged = self
    return unchanged

pre_optim_iter_no_jit deprecated ¤

pre_optim_iter_no_jit(params: GlobalParams) -> Self
Deprecated

deprecated.

Source code in src/liblaf/apple/sim/energy/energy.py
32
33
34
@deprecated("deprecated.")
def pre_optim_iter_no_jit(self, params: GlobalParams) -> Self:  # noqa: ARG002
    """Deprecated non-jitted variant of the pre-optimizer-iteration hook.

    Ignores ``params`` and returns ``self`` unchanged.
    """
    unchanged = self
    return unchanged

pre_time_step ¤

pre_time_step(params: GlobalParams) -> Self
Source code in src/liblaf/apple/sim/energy/energy.py
21
22
def pre_time_step(self, params: GlobalParams) -> Self:  # noqa: ARG002
    """Hook presumably invoked before each simulation time step.

    The default implementation ignores ``params`` and does nothing.

    Returns:
        ``self``, unchanged.
    """
    unchanged = self
    return unchanged

replace ¤

replace(**changes: Any) -> Self
Source code in src/liblaf/apple/struct/tree/_pytree.py
19
20
def replace(self, **changes: Any) -> Self:
    return dataclasses.replace(self, **changes)

tree_at ¤

tree_at(
    where: Callable[[Self], Node | Sequence[Node]],
    replace: Any | Sequence[Any] = MISSING,
    replace_fn: Callable[[Node], Any] = MISSING,
    is_leaf: Callable[[Any], bool] | None = None,
) -> Self
Source code in src/liblaf/apple/struct/tree/_pytree.py
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
def tree_at(
    self,
    where: Callable[[Self], Node | Sequence[Node]],
    replace: Any | Sequence[Any] | MISSING = MISSING,
    replace_fn: Callable[[Node], Any] | MISSING = MISSING,
    is_leaf: Callable[[Any], bool] | None = None,
) -> Self:
    """Return a copy of ``self`` with the nodes selected by ``where`` updated.

    Delegates to ``eqx.tree_at``. Only the options the caller actually
    supplied are forwarded: ``replace`` and ``replace_fn`` are dropped while
    they still hold the ``MISSING`` sentinel, and ``is_leaf`` is dropped
    while it is ``None``.
    """
    sentinel_guarded = {"replace": replace, "replace_fn": replace_fn}
    forwarded: dict[str, Any] = {
        key: value
        for key, value in sentinel_guarded.items()
        if value is not MISSING
    }
    if is_leaf is not None:
        forwarded["is_leaf"] = is_leaf
    return eqx.tree_at(where, self, **forwarded)

with_actors ¤

with_actors(actors: NodeContainer[Actor]) -> Self
Source code in src/liblaf/apple/energy/elastic/_elastic.py
35
36
37
@override
def with_actors(self, actors: struct.NodeContainer[sim.Actor]) -> Self:
    """Return a copy whose ``actor`` is replaced by the matching entry in ``actors``.

    The lookup key is the current actor's ``id``.
    """
    refreshed_actor = actors[self.actor.id]
    return self.replace(actor=refreshed_actor)