Skip to content

liblaf.apple.energy.elastic ¤

Modules:

Classes:

  • Arap

    As-Rigid-As-Possible.

  • ArapActive

    As-Rigid-As-Possible energy with muscle activation.

  • Elastic

    Elastic(actor: liblaf.apple.sim.actor.actor.Actor, *, id: str = None, hess_diag_filter: bool = True, hess_quad_filter: bool = True)

  • PhaceActive

    PhaceActive(actor: liblaf.apple.sim.actor.actor.Actor, *, id: str = None, hess_diag_filter: bool = True, hess_quad_filter: bool = True)

  • PhacePassive

    PhacePassive(actor: liblaf.apple.sim.actor.actor.Actor, *, id: str = None, hess_diag_filter: bool = True, hess_quad_filter: bool = True)

Arap ¤

Bases: Elastic

As-Rigid-As-Possible.

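\[ \Psi = \frac{\mu}{2} \|F - R\|_F^2 = \frac{\mu}{2} (I_2 - 2 I_1 + 3) \]

where \( F = R S \) is the polar decomposition of the deformation gradient, \( I_1 = \operatorname{tr}(S) \), and \( I_2 = \|F\|_F^2 = \operatorname{tr}(F^\top F) \).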

Parameters:

  • id (str, default: None ) –
  • actor (Actor) –
  • hess_diag_filter (bool, default: True ) –
  • hess_quad_filter (bool, default: True ) –

Methods:

Attributes:

__dataclass_fields__ class-attribute ¤

__dataclass_fields__: dict[str, Field[Any]]

actor class-attribute instance-attribute ¤

actor: Actor = field()

actors property ¤

hess_diag_filter class-attribute instance-attribute ¤

hess_diag_filter: bool = field(default=True, kw_only=True)

hess_quad_filter class-attribute instance-attribute ¤

hess_quad_filter: bool = field(default=True, kw_only=True)

id class-attribute instance-attribute ¤

id: str = field(default=None, kw_only=True)

mu property ¤

mu: Float[Array, ' c']

region property ¤

region: Region

__post_init__ ¤

__post_init__() -> None
Source code in src/liblaf/apple/struct/tree/_node.py
13
14
15
def __post_init__(self) -> None:
    """Assign an auto-generated ``id`` when none was given at construction."""
    if self.id is not None:
        return
    # object.__setattr__ bypasses any frozen-instance guard on attribute assignment.
    object.__setattr__(self, "id", uniq_id(self))

energy_density ¤

energy_density(
    field: Field, /, params: GlobalParams
) -> Float[Array, "c q"]
Source code in src/liblaf/apple/energy/elastic/arap/_arap.py
24
25
26
27
28
29
30
31
32
33
34
35
@override
@utils.jit(inline=True)
def energy_density(
    self, field: Field, /, params: sim.GlobalParams
) -> Float[jax.Array, "c q"]:
    """ARAP energy density at every quadrature point.

    The deformation gradient is flattened from ``(c, q, J, J)`` to
    ``(cq, J, J)`` for the kernel, and the result is reshaped back to ``(c, q)``.

    Args:
        field: The actor's field.
        params: Global simulation parameters (unused here).

    Returns:
        Energy density of shape ``(c, q)``.
    """
    region: sim.Region = self.region
    F_flat: Float[jax.Array, "cq J J"] = region.squeeze_cq(
        region.deformation_gradient(field)
    )
    [psi] = kernel.arap_energy_density_kernel(F_flat, self.mu)
    return region.unsqueeze_cq(psi)

energy_density_hess_diag ¤

energy_density_hess_diag(
    field: Field, /, params: GlobalParams
) -> Float[Array, "c q a J"]
Source code in src/liblaf/apple/energy/elastic/arap/_arap.py
50
51
52
53
54
55
56
57
58
59
60
61
62
63
@override
@utils.jit(inline=True)
def energy_density_hess_diag(
    self, field: Field, /, params: sim.GlobalParams
) -> Float[jax.Array, "c q a J"]:
    """Diagonal of the ARAP energy-density Hessian at every quadrature point.

    Args:
        field: The actor's field.
        params: Global simulation parameters (unused here).

    Returns:
        Hessian diagonal of shape ``(c, q, a, J)``.
    """
    region: sim.Region = self.region
    F: Float[jax.Array, "c q J J"] = region.deformation_gradient(field)
    F: Float[jax.Array, "cq J J"] = region.squeeze_cq(F)
    dhdX: Float[jax.Array, "cq a J"] = region.squeeze_cq(region.dhdX)
    hess_diag: Float[jax.Array, "cq a J"]
    (hess_diag,) = kernel.arap_energy_density_hess_diag_kernel(F, self.mu, dhdX)
    hess_diag: Float[jax.Array, "c q a J"] = region.unsqueeze_cq(hess_diag)
    return hess_diag

energy_density_hess_quad ¤

energy_density_hess_quad(
    field: Field, p: Field, /, params: GlobalParams
) -> Float[Array, "c q"]
Source code in src/liblaf/apple/energy/elastic/arap/_arap.py
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
@override
@utils.jit(inline=True)
def energy_density_hess_quad(
    self, field: Field, p: Field, /, params: sim.GlobalParams
) -> Float[jax.Array, "c q"]:
    """ARAP Hessian quadratic form in direction ``p`` at every quadrature point.

    Args:
        field: The actor's field at which the Hessian is evaluated.
        p: Direction field, scattered to cells via ``region.scatter``.
        params: Global simulation parameters (unused here).

    Returns:
        Per-point quadratic form of shape ``(c, q)``.
    """
    region: sim.Region = self.region
    F_flat: Float[jax.Array, "cq J J"] = region.squeeze_cq(
        region.deformation_gradient(field)
    )
    dhdX_flat: Float[jax.Array, "cq a J"] = region.squeeze_cq(region.dhdX)
    p_cells = region.scatter(p)
    [quad] = kernel.arap_energy_density_hess_quad_kernel(
        F_flat, p_cells, self.mu, dhdX_flat
    )
    return region.unsqueeze_cq(quad)

energy_density_jac ¤

energy_density_jac(
    field: Field, /, params: GlobalParams
) -> Float[Array, "c q a J"]
Source code in src/liblaf/apple/energy/elastic/_elastic.py
116
117
118
119
120
121
122
@utils.jit(inline=True)
def energy_density_jac(
    self, field: Field, /, params: sim.GlobalParams
) -> Float[Array, "c q a J"]:
    """Gradient of the energy density at each quadrature point.

    The first Piola-Kirchhoff stress is mapped through ``region.gradient_vjp``
    (presumably the transpose of the deformation-gradient operator; confirm).

    Args:
        field: The actor's field.
        params: Global simulation parameters, forwarded to the stress.

    Returns:
        Array of shape ``(c, q, a, J)``.
    """
    stress: Float[Array, "c q J J"] = self.first_piola_kirchhoff_stress(field, params)
    grad: Float[Array, "c q a J"] = self.region.gradient_vjp(stress)
    return grad

first_piola_kirchhoff_stress ¤

first_piola_kirchhoff_stress(
    field: Field, /, params: GlobalParams
) -> Float[Array, "c q J J"]
Source code in src/liblaf/apple/energy/elastic/arap/_arap.py
37
38
39
40
41
42
43
44
45
46
47
48
@override
@utils.jit(inline=True)
def first_piola_kirchhoff_stress(
    self, field: Field, /, params: sim.GlobalParams
) -> Float[jax.Array, "c q J J"]:
    """First Piola-Kirchhoff stress of the ARAP model at every quadrature point.

    Args:
        field: The actor's field.
        params: Global simulation parameters (unused here).

    Returns:
        Stress tensors of shape ``(c, q, J, J)``.
    """
    region: sim.Region = self.region
    F_flat: Float[jax.Array, "cq J J"] = region.squeeze_cq(
        region.deformation_gradient(field)
    )
    [stress] = kernel.arap_first_piola_kirchhoff_stress_kernel(F_flat, self.mu)
    return region.unsqueeze_cq(stress)

from_actor classmethod ¤

from_actor(
    actor: Actor,
    *,
    hess_diag_filter: bool = True,
    hess_quad_filter: bool = True,
) -> Self
Source code in src/liblaf/apple/energy/elastic/_elastic.py
16
17
18
19
20
21
22
23
24
25
26
27
28
@classmethod
def from_actor(
    cls,
    actor: sim.Actor,
    *,
    hess_diag_filter: bool = True,
    hess_quad_filter: bool = True,
) -> Self:
    """Build an energy bound to ``actor``.

    Args:
        actor: Actor whose field (``x[actor.id]``) the energy evaluates.
        hess_diag_filter: Whether ``hess_diag`` clamps negative entries to zero.
        hess_quad_filter: Whether ``hess_quad`` clamps negative values to zero.

    Returns:
        A new instance; ``id`` is left at its default.
    """
    filters = {
        "hess_diag_filter": hess_diag_filter,
        "hess_quad_filter": hess_quad_filter,
    }
    return cls(actor=actor, **filters)

fun ¤

fun(
    x: ArrayDict, /, params: GlobalParams
) -> Float[Array, ""]
Source code in src/liblaf/apple/energy/elastic/_elastic.py
43
44
45
46
47
48
49
@override
@utils.jit(inline=True)
def fun(self, x: struct.ArrayDict, /, params: sim.GlobalParams) -> Float[Array, ""]:
    """Total elastic energy of the actor.

    The energy density is integrated per cell with ``region.integrate``, and
    the per-cell values are summed.

    Args:
        x: State dictionary; ``x[self.actor.id]`` is the actor's field.
        params: Global simulation parameters forwarded to ``energy_density``.

    Returns:
        Scalar energy.
    """
    density: Float[Array, "c q"] = self.energy_density(x[self.actor.id], params)
    per_cell: Float[Array, " c"] = self.region.integrate(density)
    return jnp.sum(per_cell)

fun_and_jac ¤

fun_and_jac(
    x: ArrayDict, /, params: GlobalParams
) -> tuple[Float[Array, ""], ArrayDict]
Source code in src/liblaf/apple/energy/elastic/_elastic.py
92
93
94
95
96
97
@override
@utils.jit(inline=True)
def fun_and_jac(
    self, x: struct.ArrayDict, /, params: sim.GlobalParams
) -> tuple[Float[Array, ""], struct.ArrayDict]:
    """Energy and its gradient, computed by separate ``fun`` and ``jac`` calls.

    Returns:
        ``(energy, gradient)``.
    """
    energy = self.fun(x, params)
    gradient = self.jac(x, params)
    return energy, gradient

hess_diag ¤

hess_diag(
    x: ArrayDict, /, params: GlobalParams
) -> ArrayDict
Source code in src/liblaf/apple/energy/elastic/_elastic.py
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
@override
@utils.jit(inline=True)
def hess_diag(
    self, x: struct.ArrayDict, /, params: sim.GlobalParams
) -> struct.ArrayDict:
    """Diagonal of the elastic Hessian for the actor's field.

    Per-quadrature-point diagonals are optionally clamped at zero (when
    ``hess_diag_filter`` is set), integrated per cell, then gathered.

    Args:
        x: State dictionary; ``x[self.actor.id]`` is the actor's field.
        params: Global simulation parameters forwarded to the density kernel.

    Returns:
        ``ArrayDict`` mapping ``self.actor.id`` to a ``(p, J)`` array.
    """
    field: Field = x[self.actor.id]
    hess_diag: Float[Array, "c q a J"] = self.energy_density_hess_diag(
        field, params
    )
    if self.hess_diag_filter:
        # Clamp negative diagonal entries to zero before integration.
        hess_diag = jnp.clip(hess_diag, min=0.0)
    hess_diag: Float[Array, "c a J"] = self.region.integrate(hess_diag)
    hess_diag: Float[Array, "p J"] = self.region.gather(hess_diag)
    return struct.ArrayDict({self.actor.id: hess_diag})

hess_quad ¤

hess_quad(
    x: ArrayDict, p: ArrayDict, /, params: GlobalParams
) -> Float[Array, ""]
Source code in src/liblaf/apple/energy/elastic/_elastic.py
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
@override
@utils.jit(inline=True)
def hess_quad(
    self, x: struct.ArrayDict, p: struct.ArrayDict, /, params: sim.GlobalParams
) -> Float[Array, ""]:
    """Scalar Hessian quadratic form (presumably ``p^T H p``) of the energy.

    Per-point values are optionally clamped at zero (when
    ``hess_quad_filter`` is set), integrated per cell, and summed.

    Args:
        x: State dictionary; ``x[self.actor.id]`` is the actor's field.
        p: Direction dictionary keyed the same way.
        params: Global simulation parameters forwarded to the density kernel.

    Returns:
        Scalar quadratic form.
    """
    actor_id = self.actor.id
    quad: Float[Array, "c q"] = self.energy_density_hess_quad(
        x[actor_id], p[actor_id], params
    )
    if self.hess_quad_filter:
        quad = jnp.clip(quad, min=0.0)
    per_cell: Float[Array, " c"] = self.region.integrate(quad)
    return jnp.sum(per_cell)

hessp ¤

hessp(
    x: ArrayDict, p: ArrayDict, /, params: GlobalParams
) -> ArrayDict
Source code in src/liblaf/apple/sim/energy/energy.py
63
64
65
66
67
68
@utils.not_implemented
@utils.jit
def hessp(
    self, x: struct.ArrayDict, p: struct.ArrayDict, /, params: GlobalParams
) -> struct.ArrayDict:
    """Hessian-vector product obtained by applying ``math.jvp`` to ``jac``.

    NOTE(review): the effect of ``utils.not_implemented`` is not visible here.
    It presumably marks this generic fallback so callers can detect that no
    specialised ``hessp`` exists; confirm.
    """
    jac_jvp = math.jvp(self.jac)
    return jac_jvp(x, p, params)

jac ¤

jac(x: ArrayDict, /, params: GlobalParams) -> ArrayDict
Source code in src/liblaf/apple/energy/elastic/_elastic.py
51
52
53
54
55
56
57
58
@override
@utils.jit(inline=True)
def jac(self, x: struct.ArrayDict, /, params: sim.GlobalParams) -> struct.ArrayDict:
    """Gradient of the elastic energy with respect to the actor's field.

    Per-quadrature-point gradients are integrated per cell and then gathered.

    Args:
        x: State dictionary; ``x[self.actor.id]`` is the actor's field.
        params: Global simulation parameters.

    Returns:
        ``ArrayDict`` mapping ``self.actor.id`` to a ``(p, J)`` array.
    """
    local_grad: Float[Array, "c q a J"] = self.energy_density_jac(
        x[self.actor.id], params
    )
    per_cell: Float[Array, "c a J"] = self.region.integrate(local_grad)
    gathered: Float[Array, "p J"] = self.region.gather(per_cell)
    return struct.ArrayDict({self.actor.id: gathered})

jac_and_hess_diag ¤

jac_and_hess_diag(
    x: ArrayDict, /, params: GlobalParams
) -> tuple[ArrayDict, ArrayDict]
Source code in src/liblaf/apple/energy/elastic/_elastic.py
 99
100
101
102
103
104
@override
@utils.jit(inline=True)
def jac_and_hess_diag(
    self, x: struct.ArrayDict, /, params: sim.GlobalParams
) -> tuple[struct.ArrayDict, struct.ArrayDict]:
    """Gradient and Hessian diagonal, computed by separate ``jac``/``hess_diag`` calls.

    Returns:
        ``(gradient, hessian_diagonal)``.
    """
    gradient = self.jac(x, params)
    diagonal = self.hess_diag(x, params)
    return gradient, diagonal

pre_optim_iter ¤

pre_optim_iter(params: GlobalParams) -> Self
Source code in src/liblaf/apple/sim/energy/energy.py
24
25
def pre_optim_iter(self, params: GlobalParams) -> Self:  # noqa: ARG002
    """Per-optimizer-iteration hook; this default ignores ``params``.

    Returns:
        ``self``, unchanged.
    """
    return self

pre_optim_iter_jit deprecated ¤

pre_optim_iter_jit(params: GlobalParams) -> Self
Deprecated

deprecated.

Source code in src/liblaf/apple/sim/energy/energy.py
27
28
29
30
@utils.jit(inline=True, validate=False)
@deprecated("deprecated.")
def pre_optim_iter_jit(self, params: GlobalParams) -> Self:  # noqa: ARG002
    """Deprecated jitted variant of the per-iteration hook.

    Returns:
        ``self``, unchanged.
    """
    return self

pre_optim_iter_no_jit deprecated ¤

pre_optim_iter_no_jit(params: GlobalParams) -> Self
Deprecated

deprecated.

Source code in src/liblaf/apple/sim/energy/energy.py
32
33
34
@deprecated("deprecated.")
def pre_optim_iter_no_jit(self, params: GlobalParams) -> Self:  # noqa: ARG002
    """Deprecated non-jitted variant of the per-iteration hook.

    Returns:
        ``self``, unchanged.
    """
    return self

pre_time_step ¤

pre_time_step(params: GlobalParams) -> Self
Source code in src/liblaf/apple/sim/energy/energy.py
21
22
def pre_time_step(self, params: GlobalParams) -> Self:  # noqa: ARG002
    """Per-time-step hook; this default ignores ``params``.

    Returns:
        ``self``, unchanged.
    """
    return self

replace ¤

replace(**changes: Any) -> Self
Source code in src/liblaf/apple/struct/tree/_pytree.py
19
20
def replace(self, **changes: Any) -> Self:
    return dataclasses.replace(self, **changes)

tree_at ¤

tree_at(
    where: Callable[[Self], Node | Sequence[Node]],
    replace: Any | Sequence[Any] = MISSING,
    replace_fn: Callable[[Node], Any] = MISSING,
    is_leaf: Callable[[Any], bool] | None = None,
) -> Self
Source code in src/liblaf/apple/struct/tree/_pytree.py
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
def tree_at(
    self,
    where: Callable[[Self], Node | Sequence[Node]],
    replace: Any | Sequence[Any] | MISSING = MISSING,
    replace_fn: Callable[[Node], Any] | MISSING = MISSING,
    is_leaf: Callable[[Any], bool] | None = None,
) -> Self:
    """Functionally update the node(s) selected by ``where`` via ``eqx.tree_at``.

    Only the options the caller actually supplied are forwarded, so
    ``eqx.tree_at`` uses its own defaults for the rest.

    Args:
        where: Selector returning the node(s) to update.
        replace: Replacement value(s); forwarded unless ``MISSING``.
        replace_fn: Function producing replacements; forwarded unless ``MISSING``.
        is_leaf: Optional leaf predicate; forwarded unless ``None``.

    Returns:
        The updated copy produced by ``eqx.tree_at``.
    """
    candidates = (
        ("replace", replace, MISSING),
        ("replace_fn", replace_fn, MISSING),
        ("is_leaf", is_leaf, None),
    )
    forwarded: dict[str, Any] = {
        key: value for key, value, unset in candidates if value is not unset
    }
    return eqx.tree_at(where, self, **forwarded)

with_actors ¤

with_actors(actors: NodeContainer[Actor]) -> Self
Source code in src/liblaf/apple/energy/elastic/_elastic.py
35
36
37
@override
def with_actors(self, actors: struct.NodeContainer[sim.Actor]) -> Self:
    """Rebind ``actor`` to the entry of ``actors`` that shares its id.

    Args:
        actors: Container of actors indexed by id.

    Returns:
        A copy of ``self`` whose ``actor`` is ``actors[self.actor.id]``.
    """
    refreshed = actors[self.actor.id]
    return self.replace(actor=refreshed)

ArapActive ¤

Bases: Elastic

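As-Rigid-As-Possible energy with muscle activation. The formula below is the passive ARAP energy; the active kernels additionally depend on `activation` and `muscle_fraction`.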

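\[ \Psi = \frac{\mu}{2} \|F - R\|_F^2 = \frac{\mu}{2} (I_2 - 2 I_1 + 3) \]

where \( F = R S \) is the polar decomposition of the deformation gradient, \( I_1 = \operatorname{tr}(S) \), and \( I_2 = \|F\|_F^2 = \operatorname{tr}(F^\top F) \).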

Parameters:

  • id (str, default: None ) –
  • actor (Actor) –
  • hess_diag_filter (bool, default: True ) –
  • hess_quad_filter (bool, default: True ) –

Methods:

Attributes:

__dataclass_fields__ class-attribute ¤

__dataclass_fields__: dict[str, Field[Any]]

activation property ¤

activation: Float[Array, 'c J J']

actor class-attribute instance-attribute ¤

actor: Actor = field()

actors property ¤

hess_diag_filter class-attribute instance-attribute ¤

hess_diag_filter: bool = field(default=True, kw_only=True)

hess_quad_filter class-attribute instance-attribute ¤

hess_quad_filter: bool = field(default=True, kw_only=True)

id class-attribute instance-attribute ¤

id: str = field(default=None, kw_only=True)

mu property ¤

mu: Float[Array, ' c']

muscle_fraction property ¤

muscle_fraction: Float[Array, ' c']

region property ¤

region: Region

__post_init__ ¤

__post_init__() -> None
Source code in src/liblaf/apple/struct/tree/_node.py
13
14
15
def __post_init__(self) -> None:
    """Assign an auto-generated ``id`` when none was given at construction."""
    if self.id is not None:
        return
    # object.__setattr__ bypasses any frozen-instance guard on attribute assignment.
    object.__setattr__(self, "id", uniq_id(self))

energy_density ¤

energy_density(
    field: Field, /, params: GlobalParams
) -> Float[Array, "c q"]
Source code in src/liblaf/apple/energy/elastic/arap_active/_arap_active.py
32
33
34
35
36
37
38
39
40
41
42
43
44
45
@override
@utils.jit(inline=True)
def energy_density(
    self, field: Field, /, params: sim.GlobalParams
) -> Float[jax.Array, "c q"]:
    """Active ARAP energy density at every quadrature point.

    The deformation gradient is flattened to ``(cq, J, J)`` and passed to the
    kernel together with ``activation``, ``mu`` and ``muscle_fraction``. The
    result is reshaped back to ``(c, q)``.

    Args:
        field: The actor's field.
        params: Global simulation parameters (unused here).

    Returns:
        Energy density of shape ``(c, q)``.
    """
    region: sim.Region = self.region
    F_flat: Float[jax.Array, "cq J J"] = region.squeeze_cq(
        region.deformation_gradient(field)
    )
    [psi] = kernel.arap_active_energy_density_kernel(
        F_flat, self.activation, self.mu, self.muscle_fraction
    )
    return region.unsqueeze_cq(psi)

energy_density_hess_diag ¤

energy_density_hess_diag(
    field: Field, /, params: GlobalParams
) -> Float[Array, "c q a J"]
Source code in src/liblaf/apple/energy/elastic/arap_active/_arap_active.py
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
@override
@utils.jit(inline=True)
def energy_density_hess_diag(
    self, field: Field, /, params: sim.GlobalParams
) -> Float[jax.Array, "c q a J"]:
    """Diagonal of the active ARAP energy-density Hessian at every quadrature point.

    Args:
        field: The actor's field.
        params: Global simulation parameters (unused here).

    Returns:
        Hessian diagonal of shape ``(c, q, a, J)``.
    """
    region: sim.Region = self.region
    F: Float[jax.Array, "c q J J"] = region.deformation_gradient(field)
    F: Float[jax.Array, "cq J J"] = region.squeeze_cq(F)
    dhdX: Float[jax.Array, "cq a J"] = region.squeeze_cq(region.dhdX)
    hess_diag: Float[jax.Array, "cq a J"]
    (hess_diag,) = kernel.arap_active_energy_density_hess_diag_kernel(
        F, dhdX, self.activation, self.mu, self.muscle_fraction
    )
    hess_diag: Float[jax.Array, "c q a J"] = region.unsqueeze_cq(hess_diag)
    return hess_diag

energy_density_hess_quad ¤

energy_density_hess_quad(
    field: Field, p: Field, /, params: GlobalParams
) -> Float[Array, "c q"]
Source code in src/liblaf/apple/energy/elastic/arap_active/_arap_active.py
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
@override
@utils.jit(inline=True)
def energy_density_hess_quad(
    self, field: Field, p: Field, /, params: sim.GlobalParams
) -> Float[jax.Array, "c q"]:
    """Active ARAP Hessian quadratic form in direction ``p`` at every quadrature point.

    Args:
        field: The actor's field at which the Hessian is evaluated.
        p: Direction field, scattered to cells via ``region.scatter``.
        params: Global simulation parameters (unused here).

    Returns:
        Per-point quadratic form of shape ``(c, q)``.
    """
    region: sim.Region = self.region
    F_flat: Float[jax.Array, "cq J J"] = region.squeeze_cq(
        region.deformation_gradient(field)
    )
    dhdX_flat: Float[jax.Array, "cq a J"] = region.squeeze_cq(region.dhdX)
    p_cells = region.scatter(p)
    [quad] = kernel.arap_active_energy_density_hess_quad_kernel(
        F_flat, p_cells, dhdX_flat, self.activation, self.mu, self.muscle_fraction
    )
    return region.unsqueeze_cq(quad)

energy_density_jac ¤

energy_density_jac(
    field: Field, /, params: GlobalParams
) -> Float[Array, "c q a J"]
Source code in src/liblaf/apple/energy/elastic/_elastic.py
116
117
118
119
120
121
122
@utils.jit(inline=True)
def energy_density_jac(
    self, field: Field, /, params: sim.GlobalParams
) -> Float[Array, "c q a J"]:
    """Gradient of the energy density at each quadrature point.

    The first Piola-Kirchhoff stress is mapped through ``region.gradient_vjp``
    (presumably the transpose of the deformation-gradient operator; confirm).

    Args:
        field: The actor's field.
        params: Global simulation parameters, forwarded to the stress.

    Returns:
        Array of shape ``(c, q, a, J)``.
    """
    stress: Float[Array, "c q J J"] = self.first_piola_kirchhoff_stress(field, params)
    grad: Float[Array, "c q a J"] = self.region.gradient_vjp(stress)
    return grad

first_piola_kirchhoff_stress ¤

first_piola_kirchhoff_stress(
    field: Field, /, params: GlobalParams
) -> Float[Array, "c q J J"]
Source code in src/liblaf/apple/energy/elastic/arap_active/_arap_active.py
47
48
49
50
51
52
53
54
55
56
57
58
59
60
@override
@utils.jit(inline=True)
def first_piola_kirchhoff_stress(
    self, field: Field, /, params: sim.GlobalParams
) -> Float[jax.Array, "c q J J"]:
    """First Piola-Kirchhoff stress of the active ARAP model at every quadrature point.

    Args:
        field: The actor's field.
        params: Global simulation parameters (unused here).

    Returns:
        Stress tensors of shape ``(c, q, J, J)``.
    """
    region: sim.Region = self.region
    F_flat: Float[jax.Array, "cq J J"] = region.squeeze_cq(
        region.deformation_gradient(field)
    )
    [stress] = kernel.arap_active_first_piola_kirchhoff_stress_kernel(
        F_flat, self.activation, self.mu, self.muscle_fraction
    )
    return region.unsqueeze_cq(stress)

from_actor classmethod ¤

from_actor(
    actor: Actor,
    *,
    hess_diag_filter: bool = True,
    hess_quad_filter: bool = True,
) -> Self
Source code in src/liblaf/apple/energy/elastic/_elastic.py
16
17
18
19
20
21
22
23
24
25
26
27
28
@classmethod
def from_actor(
    cls,
    actor: sim.Actor,
    *,
    hess_diag_filter: bool = True,
    hess_quad_filter: bool = True,
) -> Self:
    """Build an energy bound to ``actor``.

    Args:
        actor: Actor whose field (``x[actor.id]``) the energy evaluates.
        hess_diag_filter: Whether ``hess_diag`` clamps negative entries to zero.
        hess_quad_filter: Whether ``hess_quad`` clamps negative values to zero.

    Returns:
        A new instance; ``id`` is left at its default.
    """
    filters = {
        "hess_diag_filter": hess_diag_filter,
        "hess_quad_filter": hess_quad_filter,
    }
    return cls(actor=actor, **filters)

fun ¤

fun(
    x: ArrayDict, /, params: GlobalParams
) -> Float[Array, ""]
Source code in src/liblaf/apple/energy/elastic/_elastic.py
43
44
45
46
47
48
49
@override
@utils.jit(inline=True)
def fun(self, x: struct.ArrayDict, /, params: sim.GlobalParams) -> Float[Array, ""]:
    """Total elastic energy of the actor.

    The energy density is integrated per cell with ``region.integrate``, and
    the per-cell values are summed.

    Args:
        x: State dictionary; ``x[self.actor.id]`` is the actor's field.
        params: Global simulation parameters forwarded to ``energy_density``.

    Returns:
        Scalar energy.
    """
    density: Float[Array, "c q"] = self.energy_density(x[self.actor.id], params)
    per_cell: Float[Array, " c"] = self.region.integrate(density)
    return jnp.sum(per_cell)

fun_and_jac ¤

fun_and_jac(
    x: ArrayDict, /, params: GlobalParams
) -> tuple[Float[Array, ""], ArrayDict]
Source code in src/liblaf/apple/energy/elastic/_elastic.py
92
93
94
95
96
97
@override
@utils.jit(inline=True)
def fun_and_jac(
    self, x: struct.ArrayDict, /, params: sim.GlobalParams
) -> tuple[Float[Array, ""], struct.ArrayDict]:
    """Energy and its gradient, computed by separate ``fun`` and ``jac`` calls.

    Returns:
        ``(energy, gradient)``.
    """
    energy = self.fun(x, params)
    gradient = self.jac(x, params)
    return energy, gradient

hess_diag ¤

hess_diag(
    x: ArrayDict, /, params: GlobalParams
) -> ArrayDict
Source code in src/liblaf/apple/energy/elastic/_elastic.py
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
@override
@utils.jit(inline=True)
def hess_diag(
    self, x: struct.ArrayDict, /, params: sim.GlobalParams
) -> struct.ArrayDict:
    """Diagonal of the elastic Hessian for the actor's field.

    Per-quadrature-point diagonals are optionally clamped at zero (when
    ``hess_diag_filter`` is set), integrated per cell, then gathered.

    Args:
        x: State dictionary; ``x[self.actor.id]`` is the actor's field.
        params: Global simulation parameters forwarded to the density kernel.

    Returns:
        ``ArrayDict`` mapping ``self.actor.id`` to a ``(p, J)`` array.
    """
    field: Field = x[self.actor.id]
    hess_diag: Float[Array, "c q a J"] = self.energy_density_hess_diag(
        field, params
    )
    if self.hess_diag_filter:
        # Clamp negative diagonal entries to zero before integration.
        hess_diag = jnp.clip(hess_diag, min=0.0)
    hess_diag: Float[Array, "c a J"] = self.region.integrate(hess_diag)
    hess_diag: Float[Array, "p J"] = self.region.gather(hess_diag)
    return struct.ArrayDict({self.actor.id: hess_diag})

hess_quad ¤

hess_quad(
    x: ArrayDict, p: ArrayDict, /, params: GlobalParams
) -> Float[Array, ""]
Source code in src/liblaf/apple/energy/elastic/_elastic.py
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
@override
@utils.jit(inline=True)
def hess_quad(
    self, x: struct.ArrayDict, p: struct.ArrayDict, /, params: sim.GlobalParams
) -> Float[Array, ""]:
    """Scalar Hessian quadratic form (presumably ``p^T H p``) of the energy.

    Per-point values are optionally clamped at zero (when
    ``hess_quad_filter`` is set), integrated per cell, and summed.

    Args:
        x: State dictionary; ``x[self.actor.id]`` is the actor's field.
        p: Direction dictionary keyed the same way.
        params: Global simulation parameters forwarded to the density kernel.

    Returns:
        Scalar quadratic form.
    """
    actor_id = self.actor.id
    quad: Float[Array, "c q"] = self.energy_density_hess_quad(
        x[actor_id], p[actor_id], params
    )
    if self.hess_quad_filter:
        quad = jnp.clip(quad, min=0.0)
    per_cell: Float[Array, " c"] = self.region.integrate(quad)
    return jnp.sum(per_cell)

hessp ¤

hessp(
    x: ArrayDict, p: ArrayDict, /, params: GlobalParams
) -> ArrayDict
Source code in src/liblaf/apple/sim/energy/energy.py
63
64
65
66
67
68
@utils.not_implemented
@utils.jit
def hessp(
    self, x: struct.ArrayDict, p: struct.ArrayDict, /, params: GlobalParams
) -> struct.ArrayDict:
    """Hessian-vector product obtained by applying ``math.jvp`` to ``jac``.

    NOTE(review): the effect of ``utils.not_implemented`` is not visible here.
    It presumably marks this generic fallback so callers can detect that no
    specialised ``hessp`` exists; confirm.
    """
    jac_jvp = math.jvp(self.jac)
    return jac_jvp(x, p, params)

jac ¤

jac(x: ArrayDict, /, params: GlobalParams) -> ArrayDict
Source code in src/liblaf/apple/energy/elastic/_elastic.py
51
52
53
54
55
56
57
58
@override
@utils.jit(inline=True)
def jac(self, x: struct.ArrayDict, /, params: sim.GlobalParams) -> struct.ArrayDict:
    """Gradient of the elastic energy with respect to the actor's field.

    Per-quadrature-point gradients are integrated per cell and then gathered.

    Args:
        x: State dictionary; ``x[self.actor.id]`` is the actor's field.
        params: Global simulation parameters.

    Returns:
        ``ArrayDict`` mapping ``self.actor.id`` to a ``(p, J)`` array.
    """
    local_grad: Float[Array, "c q a J"] = self.energy_density_jac(
        x[self.actor.id], params
    )
    per_cell: Float[Array, "c a J"] = self.region.integrate(local_grad)
    gathered: Float[Array, "p J"] = self.region.gather(per_cell)
    return struct.ArrayDict({self.actor.id: gathered})

jac_and_hess_diag ¤

jac_and_hess_diag(
    x: ArrayDict, /, params: GlobalParams
) -> tuple[ArrayDict, ArrayDict]
Source code in src/liblaf/apple/energy/elastic/_elastic.py
 99
100
101
102
103
104
@override
@utils.jit(inline=True)
def jac_and_hess_diag(
    self, x: struct.ArrayDict, /, params: sim.GlobalParams
) -> tuple[struct.ArrayDict, struct.ArrayDict]:
    """Gradient and Hessian diagonal, computed by separate ``jac``/``hess_diag`` calls.

    Returns:
        ``(gradient, hessian_diagonal)``.
    """
    gradient = self.jac(x, params)
    diagonal = self.hess_diag(x, params)
    return gradient, diagonal

pre_optim_iter ¤

pre_optim_iter(params: GlobalParams) -> Self
Source code in src/liblaf/apple/sim/energy/energy.py
24
25
def pre_optim_iter(self, params: GlobalParams) -> Self:  # noqa: ARG002
    """Per-optimizer-iteration hook; this default ignores ``params``.

    Returns:
        ``self``, unchanged.
    """
    return self

pre_optim_iter_jit deprecated ¤

pre_optim_iter_jit(params: GlobalParams) -> Self
Deprecated

deprecated.

Source code in src/liblaf/apple/sim/energy/energy.py
27
28
29
30
@utils.jit(inline=True, validate=False)
@deprecated("deprecated.")
def pre_optim_iter_jit(self, params: GlobalParams) -> Self:  # noqa: ARG002
    """Deprecated jitted variant of the per-iteration hook.

    Returns:
        ``self``, unchanged.
    """
    return self

pre_optim_iter_no_jit deprecated ¤

pre_optim_iter_no_jit(params: GlobalParams) -> Self
Deprecated

deprecated.

Source code in src/liblaf/apple/sim/energy/energy.py
32
33
34
@deprecated("deprecated.")
def pre_optim_iter_no_jit(self, params: GlobalParams) -> Self:  # noqa: ARG002
    """Deprecated non-jitted variant of the per-iteration hook.

    Returns:
        ``self``, unchanged.
    """
    return self

pre_time_step ¤

pre_time_step(params: GlobalParams) -> Self
Source code in src/liblaf/apple/sim/energy/energy.py
21
22
def pre_time_step(self, params: GlobalParams) -> Self:  # noqa: ARG002
    """Per-time-step hook; this default ignores ``params``.

    Returns:
        ``self``, unchanged.
    """
    return self

replace ¤

replace(**changes: Any) -> Self
Source code in src/liblaf/apple/struct/tree/_pytree.py
19
20
def replace(self, **changes: Any) -> Self:
    return dataclasses.replace(self, **changes)

tree_at ¤

tree_at(
    where: Callable[[Self], Node | Sequence[Node]],
    replace: Any | Sequence[Any] = MISSING,
    replace_fn: Callable[[Node], Any] = MISSING,
    is_leaf: Callable[[Any], bool] | None = None,
) -> Self
Source code in src/liblaf/apple/struct/tree/_pytree.py
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
def tree_at(
    self,
    where: Callable[[Self], Node | Sequence[Node]],
    replace: Any | Sequence[Any] | MISSING = MISSING,
    replace_fn: Callable[[Node], Any] | MISSING = MISSING,
    is_leaf: Callable[[Any], bool] | None = None,
) -> Self:
    """Functionally update the node(s) selected by ``where`` via ``eqx.tree_at``.

    Only the options the caller actually supplied are forwarded, so
    ``eqx.tree_at`` uses its own defaults for the rest.

    Args:
        where: Selector returning the node(s) to update.
        replace: Replacement value(s); forwarded unless ``MISSING``.
        replace_fn: Function producing replacements; forwarded unless ``MISSING``.
        is_leaf: Optional leaf predicate; forwarded unless ``None``.

    Returns:
        The updated copy produced by ``eqx.tree_at``.
    """
    candidates = (
        ("replace", replace, MISSING),
        ("replace_fn", replace_fn, MISSING),
        ("is_leaf", is_leaf, None),
    )
    forwarded: dict[str, Any] = {
        key: value for key, value, unset in candidates if value is not unset
    }
    return eqx.tree_at(where, self, **forwarded)

with_actors ¤

with_actors(actors: NodeContainer[Actor]) -> Self
Source code in src/liblaf/apple/energy/elastic/_elastic.py
35
36
37
@override
def with_actors(self, actors: struct.NodeContainer[sim.Actor]) -> Self:
    """Rebind ``actor`` to the entry of ``actors`` that shares its id.

    Args:
        actors: Container of actors indexed by id.

    Returns:
        A copy of ``self`` whose ``actor`` is ``actors[self.actor.id]``.
    """
    refreshed = actors[self.actor.id]
    return self.replace(actor=refreshed)

Elastic ¤

Bases: Energy

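Base class for elastic energies of a single actor. Subclasses supply a per-quadrature-point energy density, its first Piola-Kirchhoff stress, and its Hessian diagonal / quadratic-form terms. This class integrates those over the actor's region.

Elastic(actor: liblaf.apple.sim.actor.actor.Actor, *, id: str = None, hess_diag_filter: bool = True, hess_quad_filter: bool = True)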

Parameters:

  • id (str, default: None ) –
  • actor (Actor) –
  • hess_diag_filter (bool, default: True ) –
  • hess_quad_filter (bool, default: True ) –

Methods:

Attributes:

__dataclass_fields__ class-attribute ¤

__dataclass_fields__: dict[str, Field[Any]]

actor class-attribute instance-attribute ¤

actor: Actor = field()

actors property ¤

hess_diag_filter class-attribute instance-attribute ¤

hess_diag_filter: bool = field(default=True, kw_only=True)

hess_quad_filter class-attribute instance-attribute ¤

hess_quad_filter: bool = field(default=True, kw_only=True)

id class-attribute instance-attribute ¤

id: str = field(default=None, kw_only=True)

region property ¤

region: Region

__post_init__ ¤

__post_init__() -> None
Source code in src/liblaf/apple/struct/tree/_node.py
13
14
15
def __post_init__(self) -> None:
    """Assign an auto-generated ``id`` when none was given at construction."""
    if self.id is not None:
        return
    # object.__setattr__ bypasses any frozen-instance guard on attribute assignment.
    object.__setattr__(self, "id", uniq_id(self))

energy_density ¤

energy_density(
    field: Field, /, params: GlobalParams
) -> Float[Array, "c q"]
Source code in src/liblaf/apple/energy/elastic/_elastic.py
106
107
108
109
def energy_density(
    self, field: Field, /, params: sim.GlobalParams
) -> Float[Array, "c q"]:
    """Energy density at every quadrature point; subclasses must override.

    Args:
        field: The actor's field.
        params: Global simulation parameters.

    Returns:
        Array of shape ``(c, q)`` in concrete subclasses.

    Raises:
        NotImplementedError: Always, in this base class.
    """
    msg = f"{type(self).__name__} does not implement energy_density()"
    raise NotImplementedError(msg)

energy_density_hess_diag ¤

energy_density_hess_diag(
    field: Field, /, params: GlobalParams
) -> Float[Array, "c q a J"]
Source code in src/liblaf/apple/energy/elastic/_elastic.py
124
125
126
127
def energy_density_hess_diag(
    self, field: Field, /, params: sim.GlobalParams
) -> Float[Array, "c q a J"]:
    """Hessian diagonal of the energy density per quadrature point; subclasses must override.

    Args:
        field: The actor's field.
        params: Global simulation parameters.

    Returns:
        Array of shape ``(c, q, a, J)`` in concrete subclasses.

    Raises:
        NotImplementedError: Always, in this base class.
    """
    msg = f"{type(self).__name__} does not implement energy_density_hess_diag()"
    raise NotImplementedError(msg)

energy_density_hess_quad ¤

energy_density_hess_quad(
    field: Field, p: Field, /, params: GlobalParams
) -> Float[Array, "c q"]
Source code in src/liblaf/apple/energy/elastic/_elastic.py
129
130
131
132
def energy_density_hess_quad(
    self, field: Field, p: Field, /, params: sim.GlobalParams
) -> Float[Array, "c q"]:
    """Hessian quadratic form of the energy density in direction ``p``; subclasses must override.

    Args:
        field: The actor's field.
        p: Direction field.
        params: Global simulation parameters.

    Returns:
        Array of shape ``(c, q)`` in concrete subclasses.

    Raises:
        NotImplementedError: Always, in this base class.
    """
    msg = f"{type(self).__name__} does not implement energy_density_hess_quad()"
    raise NotImplementedError(msg)

energy_density_jac ¤

energy_density_jac(
    field: Field, /, params: GlobalParams
) -> Float[Array, "c q a J"]
Source code in src/liblaf/apple/energy/elastic/_elastic.py
116
117
118
119
120
121
122
@utils.jit(inline=True)
def energy_density_jac(
    self, field: Field, /, params: sim.GlobalParams
) -> Float[Array, "c q a J"]:
    """Gradient of the energy density at each quadrature point.

    The first Piola-Kirchhoff stress is mapped through ``region.gradient_vjp``
    (presumably the transpose of the deformation-gradient operator; confirm).

    Args:
        field: The actor's field.
        params: Global simulation parameters, forwarded to the stress.

    Returns:
        Array of shape ``(c, q, a, J)``.
    """
    stress: Float[Array, "c q J J"] = self.first_piola_kirchhoff_stress(field, params)
    grad: Float[Array, "c q a J"] = self.region.gradient_vjp(stress)
    return grad

first_piola_kirchhoff_stress ¤

first_piola_kirchhoff_stress(
    field: Field, /, params: GlobalParams
) -> Float[Array, "c q J J"]
Source code in src/liblaf/apple/energy/elastic/_elastic.py
111
112
113
114
def first_piola_kirchhoff_stress(
    self, field: Field, /, params: sim.GlobalParams
) -> Float[Array, "c q J J"]:
    """First Piola-Kirchhoff stress per quadrature point; subclasses must override.

    Args:
        field: The actor's field.
        params: Global simulation parameters.

    Returns:
        Array of shape ``(c, q, J, J)`` in concrete subclasses.

    Raises:
        NotImplementedError: Always, in this base class.
    """
    msg = f"{type(self).__name__} does not implement first_piola_kirchhoff_stress()"
    raise NotImplementedError(msg)

from_actor classmethod ¤

from_actor(
    actor: Actor,
    *,
    hess_diag_filter: bool = True,
    hess_quad_filter: bool = True,
) -> Self
Source code in src/liblaf/apple/energy/elastic/_elastic.py
16
17
18
19
20
21
22
23
24
25
26
27
28
@classmethod
def from_actor(
    cls,
    actor: sim.Actor,
    *,
    hess_diag_filter: bool = True,
    hess_quad_filter: bool = True,
) -> Self:
    """Build an energy bound to ``actor``.

    Args:
        actor: Actor whose field (``x[actor.id]``) the energy evaluates.
        hess_diag_filter: Whether ``hess_diag`` clamps negative entries to zero.
        hess_quad_filter: Whether ``hess_quad`` clamps negative values to zero.

    Returns:
        A new instance; ``id`` is left at its default.
    """
    filters = {
        "hess_diag_filter": hess_diag_filter,
        "hess_quad_filter": hess_quad_filter,
    }
    return cls(actor=actor, **filters)

fun ¤

fun(
    x: ArrayDict, /, params: GlobalParams
) -> Float[Array, ""]
Source code in src/liblaf/apple/energy/elastic/_elastic.py
43
44
45
46
47
48
49
@override
@utils.jit(inline=True)
def fun(self, x: struct.ArrayDict, /, params: sim.GlobalParams) -> Float[Array, ""]:
    """Total elastic energy of the actor.

    The energy density is integrated per cell with ``region.integrate``, and
    the per-cell values are summed.

    Args:
        x: State dictionary; ``x[self.actor.id]`` is the actor's field.
        params: Global simulation parameters forwarded to ``energy_density``.

    Returns:
        Scalar energy.
    """
    density: Float[Array, "c q"] = self.energy_density(x[self.actor.id], params)
    per_cell: Float[Array, " c"] = self.region.integrate(density)
    return jnp.sum(per_cell)

fun_and_jac ¤

fun_and_jac(
    x: ArrayDict, /, params: GlobalParams
) -> tuple[Float[Array, ""], ArrayDict]
Source code in src/liblaf/apple/energy/elastic/_elastic.py
92
93
94
95
96
97
@override
@utils.jit(inline=True)
def fun_and_jac(
    self, x: struct.ArrayDict, /, params: sim.GlobalParams
) -> tuple[Float[Array, ""], struct.ArrayDict]:
    """Energy and its gradient, computed by separate ``fun`` and ``jac`` calls.

    Returns:
        ``(energy, gradient)``.
    """
    energy = self.fun(x, params)
    gradient = self.jac(x, params)
    return energy, gradient

hess_diag ¤

hess_diag(
    x: ArrayDict, /, params: GlobalParams
) -> ArrayDict
Source code in src/liblaf/apple/energy/elastic/_elastic.py
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
@override
@utils.jit(inline=True)
def hess_diag(
    self, x: struct.ArrayDict, /, params: sim.GlobalParams
) -> struct.ArrayDict:
    """Diagonal of the elastic Hessian for the actor's field.

    Per-quadrature-point diagonals are optionally clamped at zero (when
    ``hess_diag_filter`` is set), integrated per cell, then gathered.

    Args:
        x: State dictionary; ``x[self.actor.id]`` is the actor's field.
        params: Global simulation parameters forwarded to the density kernel.

    Returns:
        ``ArrayDict`` mapping ``self.actor.id`` to a ``(p, J)`` array.
    """
    field: Field = x[self.actor.id]
    hess_diag: Float[Array, "c q a J"] = self.energy_density_hess_diag(
        field, params
    )
    if self.hess_diag_filter:
        # Clamp negative diagonal entries to zero before integration.
        hess_diag = jnp.clip(hess_diag, min=0.0)
    hess_diag: Float[Array, "c a J"] = self.region.integrate(hess_diag)
    hess_diag: Float[Array, "p J"] = self.region.gather(hess_diag)
    return struct.ArrayDict({self.actor.id: hess_diag})

hess_quad ¤

hess_quad(
    x: ArrayDict, p: ArrayDict, /, params: GlobalParams
) -> Float[Array, ""]
Source code in src/liblaf/apple/energy/elastic/_elastic.py
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
@override
@utils.jit(inline=True)
def hess_quad(
    self, x: struct.ArrayDict, p: struct.ArrayDict, /, params: sim.GlobalParams
) -> Float[Array, ""]:
    """Scalar Hessian quadratic form (presumably ``p^T H p``) of the energy.

    Per-point values are optionally clamped at zero (when
    ``hess_quad_filter`` is set), integrated per cell, and summed.

    Args:
        x: State dictionary; ``x[self.actor.id]`` is the actor's field.
        p: Direction dictionary keyed the same way.
        params: Global simulation parameters forwarded to the density kernel.

    Returns:
        Scalar quadratic form.
    """
    actor_id = self.actor.id
    quad: Float[Array, "c q"] = self.energy_density_hess_quad(
        x[actor_id], p[actor_id], params
    )
    if self.hess_quad_filter:
        quad = jnp.clip(quad, min=0.0)
    per_cell: Float[Array, " c"] = self.region.integrate(quad)
    return jnp.sum(per_cell)

hessp ¤

hessp(
    x: ArrayDict, p: ArrayDict, /, params: GlobalParams
) -> ArrayDict
Source code in src/liblaf/apple/sim/energy/energy.py
63
64
65
66
67
68
@utils.not_implemented
@utils.jit
def hessp(
    self, x: struct.ArrayDict, p: struct.ArrayDict, /, params: GlobalParams
) -> struct.ArrayDict:
    """Hessian-vector product obtained by applying ``math.jvp`` to ``jac``.

    NOTE(review): the effect of ``utils.not_implemented`` is not visible here.
    It presumably marks this generic fallback so callers can detect that no
    specialised ``hessp`` exists; confirm.
    """
    jac_jvp = math.jvp(self.jac)
    return jac_jvp(x, p, params)

jac ¤

jac(x: ArrayDict, /, params: GlobalParams) -> ArrayDict
Source code in src/liblaf/apple/energy/elastic/_elastic.py
51
52
53
54
55
56
57
58
@override
@utils.jit(inline=True)
def jac(self, x: struct.ArrayDict, /, params: sim.GlobalParams) -> struct.ArrayDict:
    """Gradient of the elastic energy with respect to the actor's field.

    Per-quadrature-point gradients are integrated per cell and then gathered.

    Args:
        x: State dictionary; ``x[self.actor.id]`` is the actor's field.
        params: Global simulation parameters.

    Returns:
        ``ArrayDict`` mapping ``self.actor.id`` to a ``(p, J)`` array.
    """
    local_grad: Float[Array, "c q a J"] = self.energy_density_jac(
        x[self.actor.id], params
    )
    per_cell: Float[Array, "c a J"] = self.region.integrate(local_grad)
    gathered: Float[Array, "p J"] = self.region.gather(per_cell)
    return struct.ArrayDict({self.actor.id: gathered})

jac_and_hess_diag ¤

jac_and_hess_diag(
    x: ArrayDict, /, params: GlobalParams
) -> tuple[ArrayDict, ArrayDict]
Source code in src/liblaf/apple/energy/elastic/_elastic.py
 99
100
101
102
103
104
@override
@utils.jit(inline=True)
def jac_and_hess_diag(
    self, x: struct.ArrayDict, /, params: sim.GlobalParams
) -> tuple[struct.ArrayDict, struct.ArrayDict]:
    """Gradient and Hessian diagonal, computed by separate ``jac``/``hess_diag`` calls.

    Returns:
        ``(gradient, hessian_diagonal)``.
    """
    gradient = self.jac(x, params)
    diagonal = self.hess_diag(x, params)
    return gradient, diagonal

pre_optim_iter ¤

pre_optim_iter(params: GlobalParams) -> Self
Source code in src/liblaf/apple/sim/energy/energy.py
24
25
def pre_optim_iter(self, params: GlobalParams) -> Self:  # noqa: ARG002
    """Per-optimizer-iteration hook; this default ignores ``params``.

    Returns:
        ``self``, unchanged.
    """
    return self

pre_optim_iter_jit deprecated ¤

pre_optim_iter_jit(params: GlobalParams) -> Self
Deprecated

deprecated.

Source code in src/liblaf/apple/sim/energy/energy.py
27
28
29
30
@utils.jit(inline=True, validate=False)
@deprecated("deprecated.")
def pre_optim_iter_jit(self, params: GlobalParams) -> Self:  # noqa: ARG002
    """Deprecated jitted variant of the per-iteration hook.

    Returns:
        ``self``, unchanged.
    """
    return self

pre_optim_iter_no_jit deprecated ¤

pre_optim_iter_no_jit(params: GlobalParams) -> Self
Deprecated

deprecated.

Source code in src/liblaf/apple/sim/energy/energy.py
32
33
34
@deprecated("deprecated.")
def pre_optim_iter_no_jit(self, params: GlobalParams) -> Self:  # noqa: ARG002
    """Deprecated non-jitted variant of the per-iteration hook.

    Returns:
        ``self``, unchanged.
    """
    return self

pre_time_step ¤

pre_time_step(params: GlobalParams) -> Self
Source code in src/liblaf/apple/sim/energy/energy.py
21
22
def pre_time_step(self, params: GlobalParams) -> Self:  # noqa: ARG002
    """Per-time-step hook; this default ignores ``params``.

    Returns:
        ``self``, unchanged.
    """
    return self

replace ¤

replace(**changes: Any) -> Self
Source code in src/liblaf/apple/struct/tree/_pytree.py
19
20
def replace(self, **changes: Any) -> Self:
    return dataclasses.replace(self, **changes)

tree_at ¤

tree_at(
    where: Callable[[Self], Node | Sequence[Node]],
    replace: Any | Sequence[Any] = MISSING,
    replace_fn: Callable[[Node], Any] = MISSING,
    is_leaf: Callable[[Any], bool] | None = None,
) -> Self
Source code in src/liblaf/apple/struct/tree/_pytree.py
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
def tree_at(
    self,
    where: Callable[[Self], Node | Sequence[Node]],
    replace: Any | Sequence[Any] | MISSING = MISSING,
    replace_fn: Callable[[Node], Any] | MISSING = MISSING,
    is_leaf: Callable[[Any], bool] | None = None,
) -> Self:
    """Out-of-place update of the node(s) selected by ``where``.

    Forwards to ``eqx.tree_at``, passing only the options the caller actually
    supplied (``MISSING`` / ``None`` mean "not supplied").
    """
    supplied: dict[str, Any] = {
        name: value
        for name, value, sentinel in (
            ("replace", replace, MISSING),
            ("replace_fn", replace_fn, MISSING),
            ("is_leaf", is_leaf, None),
        )
        if value is not sentinel
    }
    return eqx.tree_at(where, self, **supplied)

with_actors ¤

with_actors(actors: NodeContainer[Actor]) -> Self
Source code in src/liblaf/apple/energy/elastic/_elastic.py
35
36
37
@override
def with_actors(self, actors: struct.NodeContainer[sim.Actor]) -> Self:
    """Return a copy whose ``actor`` is the entry of ``actors`` with the same id."""
    refreshed_actor = actors[self.actor.id]
    return self.replace(actor=refreshed_actor)

PhaceActive ¤

Bases: Elastic

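Elastic energy whose density, stress, and Hessian terms are evaluated by the `phace_active_*` kernels, using the per-cell properties `activation`, `lambda_`, `mu`, and `muscle_fraction`.

PhaceActive(actor: liblaf.apple.sim.actor.actor.Actor, *, id: str = None, hess_diag_filter: bool = True, hess_quad_filter: bool = True)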

Parameters:

  • id (str, default: None ) –
  • actor (Actor) –
  • hess_diag_filter (bool, default: True ) –
  • hess_quad_filter (bool, default: True ) –

Methods:

Attributes:

__dataclass_fields__ class-attribute ¤

__dataclass_fields__: dict[str, Field[Any]]

activation property ¤

activation: Float[Array, ' cells J J']

actor class-attribute instance-attribute ¤

actor: Actor = field()

actors property ¤

hess_diag_filter class-attribute instance-attribute ¤

hess_diag_filter: bool = field(default=True, kw_only=True)

hess_quad_filter class-attribute instance-attribute ¤

hess_quad_filter: bool = field(default=True, kw_only=True)

id class-attribute instance-attribute ¤

id: str = field(default=None, kw_only=True)

lambda_ property ¤

lambda_: Float[Array, ' cells']

mu property ¤

mu: Float[Array, ' cells']

muscle_fraction property ¤

muscle_fraction: Float[Array, ' cells']

region property ¤

region: Region

__post_init__ ¤

__post_init__() -> None
Source code in src/liblaf/apple/struct/tree/_node.py
13
14
15
def __post_init__(self) -> None:
    """Fill in an auto-generated ``id`` when the caller did not provide one."""
    if self.id is not None:
        return
    # ``object.__setattr__`` bypasses a frozen dataclass's assignment guard.
    object.__setattr__(self, "id", uniq_id(self))

energy_density ¤

energy_density(
    field: Field, /, params: GlobalParams
) -> Float[Array, "c q"]
Source code in src/liblaf/apple/energy/elastic/phace_active/_phace_active.py
29
30
31
32
33
34
35
36
37
38
39
40
41
42
@override
@utils.jit(inline=True)
def energy_density(
    self, field: Field, /, params: sim.GlobalParams
) -> Float[jax.Array, "c q"]:
    """Evaluate the PhaceActive energy density at every (cell, q) sample.

    Args:
        field: The actor's field; the deformation gradient is derived from it
            via ``region.deformation_gradient``.
        params: Accepted for interface compatibility; not read here.

    Returns:
        Energy density annotated ``"c q"``.
    """
    region: sim.Region = self.region
    F: Float[jax.Array, "c q J J"] = region.deformation_gradient(field)
    # Flatten the cell and q axes into a single "cq" batch axis for the kernel.
    F: Float[jax.Array, "cq J J"] = region.squeeze_cq(F)
    Psi: Float[jax.Array, " cq"]
    # NOTE(review): activation/lambda_/mu/muscle_fraction are annotated per
    # cell while F is per "cq"; this lines up only if q == 1 or the kernel
    # broadcasts -- confirm.
    # The kernel returns a single-element sequence; unpack its only item.
    (Psi,) = kernel.phace_active_energy_density_kernel(
        F, self.activation, self.lambda_, self.mu, self.muscle_fraction
    )
    # Restore the separate cell and q axes.
    Psi: Float[jax.Array, "c q"] = region.unsqueeze_cq(Psi)
    return Psi

energy_density_hess_diag ¤

energy_density_hess_diag(
    field: Field, /, params: GlobalParams
) -> Float[Array, "c q a J"]
Source code in src/liblaf/apple/energy/elastic/phace_active/_phace_active.py
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
@override
@utils.jit(inline=True)
def energy_density_hess_diag(
    self, field: Field, /, params: sim.GlobalParams
) -> Float[jax.Array, "c q a J"]:
    """Per-sample diagonal of the PhaceActive energy-density Hessian.

    Args:
        field: The actor's field; the deformation gradient is derived from it.
        params: Accepted for interface compatibility; not read here.

    Returns:
        Hessian-diagonal entries annotated ``"c q a J"``.
    """
    region: sim.Region = self.region
    F: Float[jax.Array, "c q J J"] = region.deformation_gradient(field)
    # Flatten cell and q axes into one "cq" batch axis for the kernel.
    F: Float[jax.Array, "cq J J"] = region.squeeze_cq(F)
    # Shape-function gradients (presumably d(basis)/dX), flattened the same way.
    dhdX: Float[jax.Array, "cq a J"] = region.squeeze_cq(region.dhdX)
    hess_diag: Float[jax.Array, "cq a J"]
    # The kernel returns a single-element sequence; unpack its only item.
    (hess_diag,) = kernel.phace_active_energy_density_hess_diag_kernel(
        F, dhdX, self.activation, self.lambda_, self.mu, self.muscle_fraction
    )
    # Restore the separate cell and q axes.
    hess_diag: Float[jax.Array, "c q a J"] = region.unsqueeze_cq(hess_diag)
    return hess_diag

energy_density_hess_quad ¤

energy_density_hess_quad(
    field: Field, p: Field, /, params: GlobalParams
) -> Float[Array, "c q"]
Source code in src/liblaf/apple/energy/elastic/phace_active/_phace_active.py
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
@override
@utils.jit(inline=True)
def energy_density_hess_quad(
    self, field: Field, p: Field, /, params: sim.GlobalParams
) -> Float[jax.Array, "c q"]:
    """Per-sample Hessian quadratic form of the PhaceActive energy along ``p``.

    Presumably ``p^T H p`` evaluated per (cell, q) sample -- confirm against
    the kernel.

    Args:
        field: The actor's field; the deformation gradient is derived from it.
        p: Direction field, passed to the kernel through ``region.scatter``.
        params: Accepted for interface compatibility; not read here.

    Returns:
        Values annotated ``"c q"``.
    """
    region: sim.Region = self.region
    F: Float[jax.Array, "c q J J"] = region.deformation_gradient(field)
    # Flatten cell and q axes into one "cq" batch axis for the kernel.
    F: Float[jax.Array, "cq J J"] = region.squeeze_cq(F)
    dhdX: Float[jax.Array, "cq a J"] = region.squeeze_cq(region.dhdX)
    hess_quad: Float[jax.Array, " cq"]
    # NOTE(review): region.scatter(p) is not passed through squeeze_cq while F
    # and dhdX are flattened to "cq"; confirm the kernel's batch axes line up
    # (e.g. exactly one q per cell).
    # The kernel returns a single-element sequence; unpack its only item.
    (hess_quad,) = kernel.phace_active_energy_density_hess_quad_kernel(
        F,
        region.scatter(p),
        dhdX,
        self.activation,
        self.lambda_,
        self.mu,
        self.muscle_fraction,
    )
    # Restore the separate cell and q axes.
    hess_quad: Float[jax.Array, "c q"] = region.unsqueeze_cq(hess_quad)
    return hess_quad

energy_density_jac ¤

energy_density_jac(
    field: Field, /, params: GlobalParams
) -> Float[Array, "c q a J"]
Source code in src/liblaf/apple/energy/elastic/_elastic.py
116
117
118
119
120
121
122
@utils.jit(inline=True)
def energy_density_jac(
    self, field: Field, /, params: sim.GlobalParams
) -> Float[Array, "c q a J"]:
    """Per-sample gradient of the energy density, annotated ``"c q a J"``.

    Computed by passing the stress from ``first_piola_kirchhoff_stress``
    through ``self.region.gradient_vjp`` (presumably the vector-Jacobian
    product of the deformation-gradient map -- confirm in ``Region``).

    Args:
        field: The actor's field.
        params: Forwarded to ``first_piola_kirchhoff_stress``.
    """
    PK1: Float[Array, "c q J J"] = self.first_piola_kirchhoff_stress(field, params)
    dPsidx: Float[Array, "c q a J"] = self.region.gradient_vjp(PK1)
    return dPsidx

first_piola_kirchhoff_stress ¤

first_piola_kirchhoff_stress(
    field: Field, /, params: GlobalParams
) -> Float[Array, "c q J J"]
Source code in src/liblaf/apple/energy/elastic/phace_active/_phace_active.py
44
45
46
47
48
49
50
51
52
53
54
55
56
57
@override
@utils.jit(inline=True)
def first_piola_kirchhoff_stress(
    self, field: Field, /, params: sim.GlobalParams
) -> Float[jax.Array, "c q J J"]:
    """First Piola-Kirchhoff stress of the PhaceActive model per (cell, q) sample.

    Args:
        field: The actor's field; the deformation gradient is derived from it.
        params: Accepted for interface compatibility; not read here.

    Returns:
        Stress tensors annotated ``"c q J J"``.
    """
    region: sim.Region = self.region
    F: Float[jax.Array, "c q J J"] = region.deformation_gradient(field)
    # Flatten cell and q axes into one "cq" batch axis for the kernel.
    F: Float[jax.Array, "cq J J"] = region.squeeze_cq(F)
    PK1: Float[jax.Array, "cq J J"]
    # The kernel returns a single-element sequence; unpack its only item.
    (PK1,) = kernel.phace_active_first_piola_kirchhoff_stress_kernel(
        F, self.activation, self.lambda_, self.mu, self.muscle_fraction
    )
    # Restore the separate cell and q axes.
    PK1: Float[jax.Array, "c q J J"] = region.unsqueeze_cq(PK1)
    return PK1

from_actor classmethod ¤

from_actor(
    actor: Actor,
    *,
    hess_diag_filter: bool = True,
    hess_quad_filter: bool = True,
) -> Self
Source code in src/liblaf/apple/energy/elastic/_elastic.py
16
17
18
19
20
21
22
23
24
25
26
27
28
@classmethod
def from_actor(
    cls,
    actor: sim.Actor,
    *,
    hess_diag_filter: bool = True,
    hess_quad_filter: bool = True,
) -> Self:
    """Build an energy bound to ``actor``.

    Args:
        actor: The actor whose field this energy evaluates.
        hess_diag_filter: Stored on the instance; enables clipping in
            ``hess_diag``.
        hess_quad_filter: Stored on the instance; enables clipping in
            ``hess_quad``.
    """
    filter_flags = {
        "hess_diag_filter": hess_diag_filter,
        "hess_quad_filter": hess_quad_filter,
    }
    return cls(actor=actor, **filter_flags)

fun ¤

fun(
    x: ArrayDict, /, params: GlobalParams
) -> Float[Array, ""]
Source code in src/liblaf/apple/energy/elastic/_elastic.py
43
44
45
46
47
48
49
@override
@utils.jit(inline=True)
def fun(self, x: struct.ArrayDict, /, params: sim.GlobalParams) -> Float[Array, ""]:
    """Return the total elastic energy of this actor's field as a scalar.

    Args:
        x: Per-actor fields keyed by actor id; only ``x[self.actor.id]`` is read.
        params: Global simulation parameters, forwarded to ``energy_density``.
    """
    field: Field = x[self.actor.id]
    # Energy density per cell and per "q" sample.
    Psi: Float[Array, "c q"] = self.energy_density(field, params)
    # Reduce the "q" axis via the region's integration -> one value per cell.
    Psi: Float[Array, " c"] = self.region.integrate(Psi)
    return jnp.sum(Psi)

fun_and_jac ¤

fun_and_jac(
    x: ArrayDict, /, params: GlobalParams
) -> tuple[Float[Array, ""], ArrayDict]
Source code in src/liblaf/apple/energy/elastic/_elastic.py
92
93
94
95
96
97
@override
@utils.jit(inline=True)
def fun_and_jac(
    self, x: struct.ArrayDict, /, params: sim.GlobalParams
) -> tuple[Float[Array, ""], struct.ArrayDict]:
    """Return ``(energy, gradient)`` at ``x``.

    Delegates to ``fun`` first and then ``jac`` with the same arguments.
    """
    energy = self.fun(x, params)
    gradient = self.jac(x, params)
    return energy, gradient

hess_diag ¤

hess_diag(
    x: ArrayDict, /, params: GlobalParams
) -> ArrayDict
Source code in src/liblaf/apple/energy/elastic/_elastic.py
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
@override
@utils.jit(inline=True)
def hess_diag(
    self, x: struct.ArrayDict, /, params: sim.GlobalParams
) -> struct.ArrayDict:
    """Assemble the diagonal of this energy's Hessian for its actor's field.

    Args:
        x: Per-actor fields keyed by actor id; only ``x[self.actor.id]`` is read.
        params: Global simulation parameters, forwarded to
            ``energy_density_hess_diag``.

    Returns:
        A single-entry ``ArrayDict`` keyed by ``self.actor.id`` whose value is
        annotated ``"p J"``.
    """
    field: Field = x[self.actor.id]
    hess_diag: Float[Array, "c q a J"] = self.energy_density_hess_diag(
        field, params
    )
    if self.hess_diag_filter:
        # Clamp negative per-sample entries to zero before integrating.
        hess_diag = jnp.clip(hess_diag, min=0.0)
    # Reduce the "q" axis via the region's integration.
    hess_diag: Float[Array, "c a J"] = self.region.integrate(hess_diag)
    # Map per-cell contributions onto the "p" axis (see ``Region.gather``).
    hess_diag: Float[Array, "p J"] = self.region.gather(hess_diag)
    return struct.ArrayDict({self.actor.id: hess_diag})

hess_quad ¤

hess_quad(
    x: ArrayDict, p: ArrayDict, /, params: GlobalParams
) -> Float[Array, ""]
Source code in src/liblaf/apple/energy/elastic/_elastic.py
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
@override
@utils.jit(inline=True)
def hess_quad(
    self, x: struct.ArrayDict, p: struct.ArrayDict, /, params: sim.GlobalParams
) -> Float[Array, ""]:
    """Return the (presumed) Hessian quadratic form ``p^T H(x) p`` as a scalar.

    Args:
        x: Per-actor fields keyed by actor id; only ``x[self.actor.id]`` is read.
        p: Per-actor directions keyed by actor id; only ``p[self.actor.id]``
            is read.
        params: Global simulation parameters, forwarded to
            ``energy_density_hess_quad``.
    """
    field: Field = x[self.actor.id]
    field_p: Field = p[self.actor.id]
    hess_quad: Float[Array, "c q"] = self.energy_density_hess_quad(
        field, field_p, params
    )
    if self.hess_quad_filter:
        # Clamp negative per-sample contributions to zero before integrating.
        hess_quad = jnp.clip(hess_quad, min=0.0)
    # Reduce the "q" axis via the region's integration, then sum over cells.
    hess_quad: Float[Array, " c"] = self.region.integrate(hess_quad)
    hess_quad: Float[Array, ""] = jnp.sum(hess_quad)
    return hess_quad

hessp ¤

hessp(
    x: ArrayDict, p: ArrayDict, /, params: GlobalParams
) -> ArrayDict
Source code in src/liblaf/apple/sim/energy/energy.py
63
64
65
66
67
68
@utils.not_implemented
@utils.jit
def hessp(
    self, x: struct.ArrayDict, p: struct.ArrayDict, /, params: GlobalParams
) -> struct.ArrayDict:
    """Hessian-vector product obtained by applying ``math.jvp`` to ``self.jac``.

    The ``utils.not_implemented`` marker presumably flags this generic
    fallback as not overridden -- confirm its exact effect in ``utils``.
    """
    jac_jvp = math.jvp(self.jac)
    return jac_jvp(x, p, params)

jac ¤

jac(x: ArrayDict, /, params: GlobalParams) -> ArrayDict
Source code in src/liblaf/apple/energy/elastic/_elastic.py
51
52
53
54
55
56
57
58
@override
@utils.jit(inline=True)
def jac(self, x: struct.ArrayDict, /, params: sim.GlobalParams) -> struct.ArrayDict:
    """Assemble the gradient of this elastic energy for its actor's field.

    Args:
        x: Per-actor fields keyed by actor id; only ``x[self.actor.id]`` is read.
        params: Global simulation parameters, forwarded to
            ``energy_density_jac``.

    Returns:
        A single-entry ``ArrayDict`` keyed by ``self.actor.id`` whose value is
        annotated ``"p J"``.
    """
    field: Field = x[self.actor.id]
    # Per-sample gradient of the energy density ("c q a J").
    jac: Float[Array, "c q a J"] = self.energy_density_jac(field, params)
    # Reduce the "q" axis via the region's integration ("c q a J" -> "c a J").
    jac: Float[Array, "c a J"] = self.region.integrate(jac)
    # Map per-cell contributions onto the "p" axis; presumably shared entries
    # are accumulated -- confirm against ``Region.gather``.
    jac: Float[Array, "p J"] = self.region.gather(jac)
    return struct.ArrayDict({self.actor.id: jac})

jac_and_hess_diag ¤

jac_and_hess_diag(
    x: ArrayDict, /, params: GlobalParams
) -> tuple[ArrayDict, ArrayDict]
Source code in src/liblaf/apple/energy/elastic/_elastic.py
 99
100
101
102
103
104
@override
@utils.jit(inline=True)
def jac_and_hess_diag(
    self, x: struct.ArrayDict, /, params: sim.GlobalParams
) -> tuple[struct.ArrayDict, struct.ArrayDict]:
    """Return ``(gradient, hessian_diagonal)`` evaluated at ``x``.

    Delegates to ``jac`` first and then ``hess_diag`` with the same arguments.
    """
    gradient = self.jac(x, params)
    diagonal = self.hess_diag(x, params)
    return gradient, diagonal

pre_optim_iter ¤

pre_optim_iter(params: GlobalParams) -> Self
Source code in src/liblaf/apple/sim/energy/energy.py
24
25
def pre_optim_iter(self, params: GlobalParams) -> Self:  # noqa: ARG002
    """Per-optimizer-iteration hook; this default hands back ``self`` unchanged.

    ``params`` is accepted for interface compatibility and ignored.
    """
    unchanged = self
    return unchanged

pre_optim_iter_jit deprecated ¤

pre_optim_iter_jit(params: GlobalParams) -> Self
Deprecated

deprecated.

Source code in src/liblaf/apple/sim/energy/energy.py
27
28
29
30
@utils.jit(inline=True, validate=False)
@deprecated("deprecated.")
def pre_optim_iter_jit(self, params: GlobalParams) -> Self:  # noqa: ARG002
    """Deprecated jitted variant of ``pre_optim_iter``; returns ``self`` as is."""
    unchanged = self
    return unchanged

pre_optim_iter_no_jit deprecated ¤

pre_optim_iter_no_jit(params: GlobalParams) -> Self
Deprecated

deprecated.

Source code in src/liblaf/apple/sim/energy/energy.py
32
33
34
@deprecated("deprecated.")
def pre_optim_iter_no_jit(self, params: GlobalParams) -> Self:  # noqa: ARG002
    """Deprecated, non-jitted variant of ``pre_optim_iter``; returns ``self``."""
    unchanged = self
    return unchanged

pre_time_step ¤

pre_time_step(params: GlobalParams) -> Self
Source code in src/liblaf/apple/sim/energy/energy.py
21
22
def pre_time_step(self, params: GlobalParams) -> Self:  # noqa: ARG002
    """Per-time-step hook; the default ignores ``params`` and returns ``self``."""
    unchanged = self
    return unchanged

replace ¤

replace(**changes: Any) -> Self
Source code in src/liblaf/apple/struct/tree/_pytree.py
19
20
def replace(self, **changes: Any) -> Self:
    return dataclasses.replace(self, **changes)

tree_at ¤

tree_at(
    where: Callable[[Self], Node | Sequence[Node]],
    replace: Any | Sequence[Any] = MISSING,
    replace_fn: Callable[[Node], Any] = MISSING,
    is_leaf: Callable[[Any], bool] | None = None,
) -> Self
Source code in src/liblaf/apple/struct/tree/_pytree.py
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
def tree_at(
    self,
    where: Callable[[Self], Node | Sequence[Node]],
    replace: Any | Sequence[Any] | MISSING = MISSING,
    replace_fn: Callable[[Node], Any] | MISSING = MISSING,
    is_leaf: Callable[[Any], bool] | None = None,
) -> Self:
    """Out-of-place update of the node(s) selected by ``where``.

    Forwards to ``eqx.tree_at``, passing only the options the caller actually
    supplied (``MISSING`` / ``None`` mean "not supplied").
    """
    supplied: dict[str, Any] = {
        name: value
        for name, value, sentinel in (
            ("replace", replace, MISSING),
            ("replace_fn", replace_fn, MISSING),
            ("is_leaf", is_leaf, None),
        )
        if value is not sentinel
    }
    return eqx.tree_at(where, self, **supplied)

with_actors ¤

with_actors(actors: NodeContainer[Actor]) -> Self
Source code in src/liblaf/apple/energy/elastic/_elastic.py
35
36
37
@override
def with_actors(self, actors: struct.NodeContainer[sim.Actor]) -> Self:
    """Return a copy whose ``actor`` is the entry of ``actors`` with the same id."""
    refreshed_actor = actors[self.actor.id]
    return self.replace(actor=refreshed_actor)

PhacePassive ¤

Bases: Elastic

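Elastic energy whose density, stress, and Hessian terms are evaluated by the `phace_passive_*` kernels, which receive only the per-cell properties `lambda_` and `mu` (the `activation` and `muscle_fraction` properties are not passed to them).

PhacePassive(actor: liblaf.apple.sim.actor.actor.Actor, *, id: str = None, hess_diag_filter: bool = True, hess_quad_filter: bool = True)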

Parameters:

  • id (str, default: None ) –
  • actor (Actor) –
  • hess_diag_filter (bool, default: True ) –
  • hess_quad_filter (bool, default: True ) –

Methods:

Attributes:

__dataclass_fields__ class-attribute ¤

__dataclass_fields__: dict[str, Field[Any]]

activation property ¤

activation: Float[Array, 'c J J']

actor class-attribute instance-attribute ¤

actor: Actor = field()

actors property ¤

hess_diag_filter class-attribute instance-attribute ¤

hess_diag_filter: bool = field(default=True, kw_only=True)

hess_quad_filter class-attribute instance-attribute ¤

hess_quad_filter: bool = field(default=True, kw_only=True)

id class-attribute instance-attribute ¤

id: str = field(default=None, kw_only=True)

lambda_ property ¤

lambda_: Float[Array, ' cells']

mu property ¤

mu: Float[Array, ' cells']

muscle_fraction property ¤

muscle_fraction: Float[Array, ' cells']

region property ¤

region: Region

__post_init__ ¤

__post_init__() -> None
Source code in src/liblaf/apple/struct/tree/_node.py
13
14
15
def __post_init__(self) -> None:
    """Fill in an auto-generated ``id`` when the caller did not provide one."""
    if self.id is not None:
        return
    # ``object.__setattr__`` bypasses a frozen dataclass's assignment guard.
    object.__setattr__(self, "id", uniq_id(self))

energy_density ¤

energy_density(
    field: Field, /, params: GlobalParams
) -> Float[Array, "c q"]
Source code in src/liblaf/apple/energy/elastic/phace_passive/_phace_passive.py
29
30
31
32
33
34
35
36
37
38
39
40
@override
@utils.jit(inline=True)
def energy_density(
    self, field: Field, /, params: sim.GlobalParams
) -> Float[jax.Array, "c q"]:
    """Evaluate the PhacePassive energy density at every (cell, q) sample.

    Args:
        field: The actor's field; the deformation gradient is derived from it
            via ``region.deformation_gradient``.
        params: Accepted for interface compatibility; not read here.

    Returns:
        Energy density annotated ``"c q"``.
    """
    region: sim.Region = self.region
    F: Float[jax.Array, "c q J J"] = region.deformation_gradient(field)
    # Flatten the cell and q axes into a single "cq" batch axis for the kernel.
    F: Float[jax.Array, "cq J J"] = region.squeeze_cq(F)
    Psi: Float[jax.Array, " cq"]
    # NOTE(review): lambda_/mu are annotated per cell while F is per "cq";
    # this lines up only if q == 1 or the kernel broadcasts -- confirm.
    # The kernel returns a single-element sequence; unpack its only item.
    (Psi,) = kernel.phace_passive_energy_density_kernel(F, self.lambda_, self.mu)
    # Restore the separate cell and q axes.
    Psi: Float[jax.Array, "c q"] = region.unsqueeze_cq(Psi)
    return Psi

energy_density_hess_diag ¤

energy_density_hess_diag(
    field: Field, /, params: GlobalParams
) -> Float[Array, "c q a J"]
Source code in src/liblaf/apple/energy/elastic/phace_passive/_phace_passive.py
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
@override
@utils.jit(inline=True)
def energy_density_hess_diag(
    self, field: Field, /, params: sim.GlobalParams
) -> Float[jax.Array, "c q a J"]:
    """Per-sample diagonal of the PhacePassive energy-density Hessian.

    Args:
        field: The actor's field; the deformation gradient is derived from it.
        params: Accepted for interface compatibility; not read here.

    Returns:
        Hessian-diagonal entries annotated ``"c q a J"``.
    """
    region: sim.Region = self.region
    F: Float[jax.Array, "c q J J"] = region.deformation_gradient(field)
    # Flatten cell and q axes into one "cq" batch axis for the kernel.
    F: Float[jax.Array, "cq J J"] = region.squeeze_cq(F)
    # Shape-function gradients (presumably d(basis)/dX), flattened the same way.
    dhdX: Float[jax.Array, "cq a J"] = region.squeeze_cq(region.dhdX)
    hess_diag: Float[jax.Array, "cq a J"]
    # The kernel returns a single-element sequence; unpack its only item.
    (hess_diag,) = kernel.phace_passive_energy_density_hess_diag_kernel(
        F, dhdX, self.lambda_, self.mu
    )
    # Restore the separate cell and q axes.
    hess_diag: Float[jax.Array, "c q a J"] = region.unsqueeze_cq(hess_diag)
    return hess_diag

energy_density_hess_quad ¤

energy_density_hess_quad(
    field: Field, p: Field, /, params: GlobalParams
) -> Float[Array, "c q"]
Source code in src/liblaf/apple/energy/elastic/phace_passive/_phace_passive.py
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
@override
@utils.jit(inline=True)
def energy_density_hess_quad(
    self, field: Field, p: Field, /, params: sim.GlobalParams
) -> Float[jax.Array, "c q"]:
    """Per-sample Hessian quadratic form of the PhacePassive energy along ``p``.

    Presumably ``p^T H p`` evaluated per (cell, q) sample -- confirm against
    the kernel.

    Args:
        field: The actor's field; the deformation gradient is derived from it.
        p: Direction field, passed to the kernel through ``region.scatter``.
        params: Accepted for interface compatibility; not read here.

    Returns:
        Values annotated ``"c q"``.
    """
    region: sim.Region = self.region
    F: Float[jax.Array, "c q J J"] = region.deformation_gradient(field)
    # Flatten cell and q axes into one "cq" batch axis for the kernel.
    F: Float[jax.Array, "cq J J"] = region.squeeze_cq(F)
    dhdX: Float[jax.Array, "cq a J"] = region.squeeze_cq(region.dhdX)
    hess_quad: Float[jax.Array, " cq"]
    # NOTE(review): region.scatter(p) is not passed through squeeze_cq while F
    # and dhdX are flattened to "cq"; confirm the kernel's batch axes line up
    # (e.g. exactly one q per cell).
    # The kernel returns a single-element sequence; unpack its only item.
    (hess_quad,) = kernel.phace_passive_energy_density_hess_quad_kernel(
        F, region.scatter(p), dhdX, self.lambda_, self.mu
    )
    # Restore the separate cell and q axes.
    hess_quad: Float[jax.Array, "c q"] = region.unsqueeze_cq(hess_quad)
    return hess_quad

energy_density_jac ¤

energy_density_jac(
    field: Field, /, params: GlobalParams
) -> Float[Array, "c q a J"]
Source code in src/liblaf/apple/energy/elastic/_elastic.py
116
117
118
119
120
121
122
@utils.jit(inline=True)
def energy_density_jac(
    self, field: Field, /, params: sim.GlobalParams
) -> Float[Array, "c q a J"]:
    """Per-sample gradient of the energy density, annotated ``"c q a J"``.

    Computed by passing the stress from ``first_piola_kirchhoff_stress``
    through ``self.region.gradient_vjp`` (presumably the vector-Jacobian
    product of the deformation-gradient map -- confirm in ``Region``).

    Args:
        field: The actor's field.
        params: Forwarded to ``first_piola_kirchhoff_stress``.
    """
    PK1: Float[Array, "c q J J"] = self.first_piola_kirchhoff_stress(field, params)
    dPsidx: Float[Array, "c q a J"] = self.region.gradient_vjp(PK1)
    return dPsidx

first_piola_kirchhoff_stress ¤

first_piola_kirchhoff_stress(
    field: Field, /, params: GlobalParams
) -> Float[Array, "c q J J"]
Source code in src/liblaf/apple/energy/elastic/phace_passive/_phace_passive.py
42
43
44
45
46
47
48
49
50
51
52
53
54
55
@override
@utils.jit(inline=True)
def first_piola_kirchhoff_stress(
    self, field: Field, /, params: sim.GlobalParams
) -> Float[jax.Array, "c q J J"]:
    """First Piola-Kirchhoff stress of the PhacePassive model per (cell, q) sample.

    Args:
        field: The actor's field; the deformation gradient is derived from it.
        params: Accepted for interface compatibility; not read here.

    Returns:
        Stress tensors annotated ``"c q J J"``.
    """
    region: sim.Region = self.region
    F: Float[jax.Array, "c q J J"] = region.deformation_gradient(field)
    # Flatten cell and q axes into one "cq" batch axis for the kernel.
    F: Float[jax.Array, "cq J J"] = region.squeeze_cq(F)
    PK1: Float[jax.Array, "cq J J"]
    # The kernel returns a single-element sequence; unpack its only item.
    (PK1,) = kernel.phace_passive_first_piola_kirchhoff_stress_kernel(
        F, self.lambda_, self.mu
    )
    # Restore the separate cell and q axes.
    PK1: Float[jax.Array, "c q J J"] = region.unsqueeze_cq(PK1)
    return PK1

from_actor classmethod ¤

from_actor(
    actor: Actor,
    *,
    hess_diag_filter: bool = True,
    hess_quad_filter: bool = True,
) -> Self
Source code in src/liblaf/apple/energy/elastic/_elastic.py
16
17
18
19
20
21
22
23
24
25
26
27
28
@classmethod
def from_actor(
    cls,
    actor: sim.Actor,
    *,
    hess_diag_filter: bool = True,
    hess_quad_filter: bool = True,
) -> Self:
    """Build an energy bound to ``actor``.

    Args:
        actor: The actor whose field this energy evaluates.
        hess_diag_filter: Stored on the instance; enables clipping in
            ``hess_diag``.
        hess_quad_filter: Stored on the instance; enables clipping in
            ``hess_quad``.
    """
    filter_flags = {
        "hess_diag_filter": hess_diag_filter,
        "hess_quad_filter": hess_quad_filter,
    }
    return cls(actor=actor, **filter_flags)

fun ¤

fun(
    x: ArrayDict, /, params: GlobalParams
) -> Float[Array, ""]
Source code in src/liblaf/apple/energy/elastic/_elastic.py
43
44
45
46
47
48
49
@override
@utils.jit(inline=True)
def fun(self, x: struct.ArrayDict, /, params: sim.GlobalParams) -> Float[Array, ""]:
    """Return the total elastic energy of this actor's field as a scalar.

    Args:
        x: Per-actor fields keyed by actor id; only ``x[self.actor.id]`` is read.
        params: Global simulation parameters, forwarded to ``energy_density``.
    """
    field: Field = x[self.actor.id]
    # Energy density per cell and per "q" sample.
    Psi: Float[Array, "c q"] = self.energy_density(field, params)
    # Reduce the "q" axis via the region's integration -> one value per cell.
    Psi: Float[Array, " c"] = self.region.integrate(Psi)
    return jnp.sum(Psi)

fun_and_jac ¤

fun_and_jac(
    x: ArrayDict, /, params: GlobalParams
) -> tuple[Float[Array, ""], ArrayDict]
Source code in src/liblaf/apple/energy/elastic/_elastic.py
92
93
94
95
96
97
@override
@utils.jit(inline=True)
def fun_and_jac(
    self, x: struct.ArrayDict, /, params: sim.GlobalParams
) -> tuple[Float[Array, ""], struct.ArrayDict]:
    """Return ``(energy, gradient)`` at ``x``.

    Delegates to ``fun`` first and then ``jac`` with the same arguments.
    """
    energy = self.fun(x, params)
    gradient = self.jac(x, params)
    return energy, gradient

hess_diag ¤

hess_diag(
    x: ArrayDict, /, params: GlobalParams
) -> ArrayDict
Source code in src/liblaf/apple/energy/elastic/_elastic.py
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
@override
@utils.jit(inline=True)
def hess_diag(
    self, x: struct.ArrayDict, /, params: sim.GlobalParams
) -> struct.ArrayDict:
    """Assemble the diagonal of this energy's Hessian for its actor's field.

    Args:
        x: Per-actor fields keyed by actor id; only ``x[self.actor.id]`` is read.
        params: Global simulation parameters, forwarded to
            ``energy_density_hess_diag``.

    Returns:
        A single-entry ``ArrayDict`` keyed by ``self.actor.id`` whose value is
        annotated ``"p J"``.
    """
    field: Field = x[self.actor.id]
    hess_diag: Float[Array, "c q a J"] = self.energy_density_hess_diag(
        field, params
    )
    if self.hess_diag_filter:
        # Clamp negative per-sample entries to zero before integrating.
        hess_diag = jnp.clip(hess_diag, min=0.0)
    # Reduce the "q" axis via the region's integration.
    hess_diag: Float[Array, "c a J"] = self.region.integrate(hess_diag)
    # Map per-cell contributions onto the "p" axis (see ``Region.gather``).
    hess_diag: Float[Array, "p J"] = self.region.gather(hess_diag)
    return struct.ArrayDict({self.actor.id: hess_diag})

hess_quad ¤

hess_quad(
    x: ArrayDict, p: ArrayDict, /, params: GlobalParams
) -> Float[Array, ""]
Source code in src/liblaf/apple/energy/elastic/_elastic.py
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
@override
@utils.jit(inline=True)
def hess_quad(
    self, x: struct.ArrayDict, p: struct.ArrayDict, /, params: sim.GlobalParams
) -> Float[Array, ""]:
    """Return the (presumed) Hessian quadratic form ``p^T H(x) p`` as a scalar.

    Args:
        x: Per-actor fields keyed by actor id; only ``x[self.actor.id]`` is read.
        p: Per-actor directions keyed by actor id; only ``p[self.actor.id]``
            is read.
        params: Global simulation parameters, forwarded to
            ``energy_density_hess_quad``.
    """
    field: Field = x[self.actor.id]
    field_p: Field = p[self.actor.id]
    hess_quad: Float[Array, "c q"] = self.energy_density_hess_quad(
        field, field_p, params
    )
    if self.hess_quad_filter:
        # Clamp negative per-sample contributions to zero before integrating.
        hess_quad = jnp.clip(hess_quad, min=0.0)
    # Reduce the "q" axis via the region's integration, then sum over cells.
    hess_quad: Float[Array, " c"] = self.region.integrate(hess_quad)
    hess_quad: Float[Array, ""] = jnp.sum(hess_quad)
    return hess_quad

hessp ¤

hessp(
    x: ArrayDict, p: ArrayDict, /, params: GlobalParams
) -> ArrayDict
Source code in src/liblaf/apple/sim/energy/energy.py
63
64
65
66
67
68
@utils.not_implemented
@utils.jit
def hessp(
    self, x: struct.ArrayDict, p: struct.ArrayDict, /, params: GlobalParams
) -> struct.ArrayDict:
    """Hessian-vector product obtained by applying ``math.jvp`` to ``self.jac``.

    The ``utils.not_implemented`` marker presumably flags this generic
    fallback as not overridden -- confirm its exact effect in ``utils``.
    """
    jac_jvp = math.jvp(self.jac)
    return jac_jvp(x, p, params)

jac ¤

jac(x: ArrayDict, /, params: GlobalParams) -> ArrayDict
Source code in src/liblaf/apple/energy/elastic/_elastic.py
51
52
53
54
55
56
57
58
@override
@utils.jit(inline=True)
def jac(self, x: struct.ArrayDict, /, params: sim.GlobalParams) -> struct.ArrayDict:
    """Assemble the gradient of this elastic energy for its actor's field.

    Args:
        x: Per-actor fields keyed by actor id; only ``x[self.actor.id]`` is read.
        params: Global simulation parameters, forwarded to
            ``energy_density_jac``.

    Returns:
        A single-entry ``ArrayDict`` keyed by ``self.actor.id`` whose value is
        annotated ``"p J"``.
    """
    field: Field = x[self.actor.id]
    # Per-sample gradient of the energy density ("c q a J").
    jac: Float[Array, "c q a J"] = self.energy_density_jac(field, params)
    # Reduce the "q" axis via the region's integration ("c q a J" -> "c a J").
    jac: Float[Array, "c a J"] = self.region.integrate(jac)
    # Map per-cell contributions onto the "p" axis; presumably shared entries
    # are accumulated -- confirm against ``Region.gather``.
    jac: Float[Array, "p J"] = self.region.gather(jac)
    return struct.ArrayDict({self.actor.id: jac})

jac_and_hess_diag ¤

jac_and_hess_diag(
    x: ArrayDict, /, params: GlobalParams
) -> tuple[ArrayDict, ArrayDict]
Source code in src/liblaf/apple/energy/elastic/_elastic.py
 99
100
101
102
103
104
@override
@utils.jit(inline=True)
def jac_and_hess_diag(
    self, x: struct.ArrayDict, /, params: sim.GlobalParams
) -> tuple[struct.ArrayDict, struct.ArrayDict]:
    """Return ``(gradient, hessian_diagonal)`` evaluated at ``x``.

    Delegates to ``jac`` first and then ``hess_diag`` with the same arguments.
    """
    gradient = self.jac(x, params)
    diagonal = self.hess_diag(x, params)
    return gradient, diagonal

pre_optim_iter ¤

pre_optim_iter(params: GlobalParams) -> Self
Source code in src/liblaf/apple/sim/energy/energy.py
24
25
def pre_optim_iter(self, params: GlobalParams) -> Self:  # noqa: ARG002
    """Per-optimizer-iteration hook; this default hands back ``self`` unchanged.

    ``params`` is accepted for interface compatibility and ignored.
    """
    unchanged = self
    return unchanged

pre_optim_iter_jit deprecated ¤

pre_optim_iter_jit(params: GlobalParams) -> Self
Deprecated

deprecated.

Source code in src/liblaf/apple/sim/energy/energy.py
27
28
29
30
@utils.jit(inline=True, validate=False)
@deprecated("deprecated.")
def pre_optim_iter_jit(self, params: GlobalParams) -> Self:  # noqa: ARG002
    """Deprecated jitted variant of ``pre_optim_iter``; returns ``self`` as is."""
    unchanged = self
    return unchanged

pre_optim_iter_no_jit deprecated ¤

pre_optim_iter_no_jit(params: GlobalParams) -> Self
Deprecated

deprecated.

Source code in src/liblaf/apple/sim/energy/energy.py
32
33
34
@deprecated("deprecated.")
def pre_optim_iter_no_jit(self, params: GlobalParams) -> Self:  # noqa: ARG002
    """Deprecated, non-jitted variant of ``pre_optim_iter``; returns ``self``."""
    unchanged = self
    return unchanged

pre_time_step ¤

pre_time_step(params: GlobalParams) -> Self
Source code in src/liblaf/apple/sim/energy/energy.py
21
22
def pre_time_step(self, params: GlobalParams) -> Self:  # noqa: ARG002
    """Per-time-step hook; the default ignores ``params`` and returns ``self``."""
    unchanged = self
    return unchanged

replace ¤

replace(**changes: Any) -> Self
Source code in src/liblaf/apple/struct/tree/_pytree.py
19
20
def replace(self, **changes: Any) -> Self:
    return dataclasses.replace(self, **changes)

tree_at ¤

tree_at(
    where: Callable[[Self], Node | Sequence[Node]],
    replace: Any | Sequence[Any] = MISSING,
    replace_fn: Callable[[Node], Any] = MISSING,
    is_leaf: Callable[[Any], bool] | None = None,
) -> Self
Source code in src/liblaf/apple/struct/tree/_pytree.py
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
def tree_at(
    self,
    where: Callable[[Self], Node | Sequence[Node]],
    replace: Any | Sequence[Any] | MISSING = MISSING,
    replace_fn: Callable[[Node], Any] | MISSING = MISSING,
    is_leaf: Callable[[Any], bool] | None = None,
) -> Self:
    """Out-of-place update of the node(s) selected by ``where``.

    Forwards to ``eqx.tree_at``, passing only the options the caller actually
    supplied (``MISSING`` / ``None`` mean "not supplied").
    """
    supplied: dict[str, Any] = {
        name: value
        for name, value, sentinel in (
            ("replace", replace, MISSING),
            ("replace_fn", replace_fn, MISSING),
            ("is_leaf", is_leaf, None),
        )
        if value is not sentinel
    }
    return eqx.tree_at(where, self, **supplied)

with_actors ¤

with_actors(actors: NodeContainer[Actor]) -> Self
Source code in src/liblaf/apple/energy/elastic/_elastic.py
35
36
37
@override
def with_actors(self, actors: struct.NodeContainer[sim.Actor]) -> Self:
    """Return a copy whose ``actor`` is the entry of ``actors`` with the same id."""
    refreshed_actor = actors[self.actor.id]
    return self.replace(actor=refreshed_actor)