Skip to content

liblaf.apple.jax.sim.energy ¤

Modules:

Classes:

Functions:

ARAP ¤

Bases: Elastic

Parameters:

  • id (str, default: <dynamic> ) –
  • requires_grad (Sequence[str], default: () ) –
  • region (Region) –
  • mu (Float[Array, c]) –

Methods:

Attributes:

id class-attribute instance-attribute ¤

id: str = field(
    default=Factory(_default_id, takes_self=True),
    kw_only=True,
)

mu instance-attribute ¤

mu: Float[Array, ' c']

region instance-attribute ¤

region: Region

requires_grad class-attribute instance-attribute ¤

requires_grad: Sequence[str] = field(
    default=(), kw_only=True
)

energy_density ¤

energy_density(
    F: Float[Array, "c q J J"],
) -> Float[Array, "c q"]
Source code in src/liblaf/apple/jax/sim/energy/elastic/_arap.py
20
21
22
23
24
def energy_density(self, F: Float[Array, "c q J J"]) -> Float[Array, "c q"]:
    """Return the ARAP energy density ``0.5 * mu * ||F - R||_F^2``.

    ``R`` is the first output of ``math.polar_rv(F)`` (presumably the rotation
    factor of a polar decomposition) and ``mu`` is a per-cell coefficient that
    is broadcast over the quadrature axis.

    Args:
        F: Deformation gradients, shape ``(c, q, J, J)``.

    Returns:
        Energy density per cell and quadrature point, shape ``(c, q)``.
    """
    R, _ = math.polar_rv(F)
    mu_cq: Float[Array, " c #q"] = self.mu[:, None]
    return 0.5 * mu_cq * math.fro_norm_square(F - R)

from_geometry classmethod ¤

from_geometry(
    geometry: Geometry,
    *,
    quadrature: Scheme | None = None,
    **kwargs,
) -> Self
Source code in src/liblaf/apple/jax/sim/energy/elastic/_elastic.py
22
23
24
25
26
27
28
29
@classmethod
def from_geometry(
    cls, geometry: Geometry, *, quadrature: Scheme | None = None, **kwargs
) -> Self:
    """Build the energy on a gradient-enabled region created from ``geometry``.

    Args:
        geometry: Mesh geometry to integrate over.
        quadrature: Optional quadrature scheme passed to ``Region.from_geometry``.
        **kwargs: Forwarded unchanged to ``cls.from_region``.
    """
    return cls.from_region(
        Region.from_geometry(geometry, grad=True, quadrature=quadrature), **kwargs
    )

from_pyvista classmethod ¤

from_pyvista(
    mesh: UnstructuredGrid,
    *,
    quadrature: Scheme | None = None,
    **kwargs,
) -> Self
Source code in src/liblaf/apple/jax/sim/energy/elastic/_elastic.py
31
32
33
34
35
36
@classmethod
def from_pyvista(
    cls, mesh: pv.UnstructuredGrid, *, quadrature: Scheme | None = None, **kwargs
) -> Self:
    """Build the energy from a PyVista unstructured grid.

    The mesh is converted with ``Geometry.from_pyvista`` and handed to
    ``cls.from_geometry`` together with ``quadrature`` and ``**kwargs``.
    """
    return cls.from_geometry(
        Geometry.from_pyvista(mesh), quadrature=quadrature, **kwargs
    )

from_region classmethod ¤

from_region(region: Region, **kwargs) -> Self
Source code in src/liblaf/apple/jax/sim/energy/elastic/_arap.py
16
17
18
@classmethod
def from_region(cls, region: Region, **kwargs) -> Self:
    """Construct the energy on ``region``, taking ``mu`` from its cell data.

    Extra keyword arguments are forwarded to the constructor.
    """
    mu_cells = region.cell_data["mu"]
    return cls(region=region, mu=mu_cells, **kwargs)

fun ¤

fun(u: Vector) -> Scalar
Source code in src/liblaf/apple/jax/sim/energy/elastic/_elastic.py
38
39
40
41
42
@override
def fun(self, u: Vector) -> Scalar:
    """Return the total elastic energy for the degrees of freedom ``u``.

    ``self.region`` turns ``u`` into per-quadrature-point deformation
    gradients. :meth:`energy_density` maps them to energy densities, which
    ``self.region.integrate`` integrates before everything is summed.
    """
    deformation: Float[Array, "c q J J"] = self.region.deformation_gradient(u)
    density: Float[Array, "c q"] = self.energy_density(deformation)
    integrated = self.region.integrate(density)
    return integrated.sum()

fun_and_jac ¤

fun_and_jac(u: Vector) -> tuple[Scalar, Updates]
Source code in src/liblaf/apple/jax/sim/energy/_energy.py
40
41
42
43
44
45
def fun_and_jac(self, u: Vector) -> tuple[Scalar, Updates]:
    """Evaluate the energy and its gradient at ``u`` in a single pass.

    Returns:
        ``(value, (grad, index))``. ``grad`` is the dense gradient from
        ``jax.value_and_grad``; ``index`` is ``arange`` over its leading axis.
    """
    value_and_grad_fn = jax.value_and_grad(self.fun)
    energy, grad = value_and_grad_fn(u)
    return energy, (grad, jnp.arange(grad.shape[0]))

hess_diag ¤

hess_diag(u: Vector) -> Updates
Source code in src/liblaf/apple/jax/sim/energy/_energy.py
23
24
25
26
def hess_diag(self, u: Vector) -> Updates:
    """Return the Hessian diagonal of ``self.fun`` at ``u`` as ``(data, index)``.

    ``data`` comes from the project helper ``math.hess_diag``; ``index``
    enumerates its leading axis.
    """
    diagonal = math.hess_diag(self.fun, u)
    return diagonal, jnp.arange(diagonal.shape[0])

hess_prod ¤

hess_prod(u: Vector, p: Vector) -> Updates
Source code in src/liblaf/apple/jax/sim/energy/_energy.py
28
29
30
31
32
def hess_prod(self, u: Vector, p: Vector) -> Updates:
    """Return the Hessian-vector product ``H(u) @ p`` as ``(data, index)``.

    Forward-over-reverse: ``jax.jvp`` pushes ``p`` through ``jax.grad(self.fun)``.
    """
    grad_fn = jax.grad(self.fun)
    _primal, hvp = jax.jvp(grad_fn, (u,), (p,))
    return hvp, jnp.arange(hvp.shape[0])

hess_quad ¤

hess_quad(u: Vector, p: Vector) -> Scalar
Source code in src/liblaf/apple/jax/sim/energy/_energy.py
34
35
36
37
38
def hess_quad(self, u: Vector, p: Vector) -> Scalar:
    """Return the quadratic form ``p^T H(u) p``.

    The product entries from :meth:`hess_prod` are dotted with ``p`` gathered
    at the returned indices.
    """
    hp, idx = self.hess_prod(u, p)
    gathered_p = p[idx]
    return jnp.vdot(gathered_p, hp)

jac ¤

jac(u: Vector) -> Updates
Source code in src/liblaf/apple/jax/sim/energy/_energy.py
18
19
20
21
def jac(self, u: Vector) -> Updates:
    """Return the gradient of ``self.fun`` at ``u`` as ``(data, index)`` updates.

    ``index`` is ``arange`` over the gradient's leading axis, so each entry is
    addressed exactly once.
    """
    grad_fn = jax.grad(self.fun)
    gradient = grad_fn(u)
    return gradient, jnp.arange(gradient.shape[0])

jac_and_hess_diag ¤

jac_and_hess_diag(u: Vector) -> tuple[Updates, Updates]
Source code in src/liblaf/apple/jax/sim/energy/_energy.py
47
48
49
50
def jac_and_hess_diag(self, u: Vector) -> tuple[Updates, Updates]:
    """Return ``(self.jac(u), self.hess_diag(u))``; the gradient is computed first."""
    gradient_updates: Updates = self.jac(u)
    return gradient_updates, self.hess_diag(u)

mixed_derivative_prod ¤

mixed_derivative_prod(
    u: Vector, p: Vector
) -> dict[str, Array]
Source code in src/liblaf/apple/jax/sim/energy/_energy.py
52
53
54
55
56
def mixed_derivative_prod(self, u: Vector, p: Vector) -> dict[str, Array]:
    """Collect mixed-derivative products for each name in ``requires_grad``.

    For each name ``n``, ``self.mixed_derivative_prod_<n>(u, p)`` is looked up
    and called. Results are keyed by ``n``, in ``requires_grad`` order. A
    missing method raises ``AttributeError``.
    """
    return {
        name: getattr(self, f"mixed_derivative_prod_{name}")(u, p)
        for name in self.requires_grad
    }

ARAPActive ¤

Bases: Elastic

Parameters:

  • id (str, default: <dynamic> ) –
  • requires_grad (Sequence[str], default: () ) –
  • region (Region) –
  • activation (Float[Array, 'c 6']) –
  • mu (Float[Array, c]) –

Methods:

Attributes:

activation class-attribute instance-attribute ¤

activation: Float[Array, 'c 6'] = array()

id class-attribute instance-attribute ¤

id: str = field(
    default=Factory(_default_id, takes_self=True),
    kw_only=True,
)

mu class-attribute instance-attribute ¤

mu: Float[Array, ' c'] = array()

region instance-attribute ¤

region: Region

requires_grad class-attribute instance-attribute ¤

requires_grad: Sequence[str] = field(
    default=(), kw_only=True
)

energy_density ¤

energy_density(
    F: Float[Array, "c q J J"],
) -> Float[Array, "c q"]
Source code in src/liblaf/apple/jax/sim/energy/elastic/_arap_active.py
32
33
34
35
36
37
38
39
40
@override
def energy_density(self, F: Float[Array, "c q J J"]) -> Float[Array, "c q"]:
    """Return the activated ARAP energy density ``mu * ||F - R A||_F^2``.

    ``A`` is built per cell by ``utils.make_activation(self.activation)`` and
    broadcast over quadrature points. ``R`` is the first output of
    ``math.polar_rv(F)``.

    NOTE(review): ``ARAP.energy_density`` carries a ``0.5`` factor that this
    method does not, so at ``A = I`` the two differ by 2x — confirm intended.
    """
    activation_matrix: Float[Array, " c #q J J"] = utils.make_activation(
        self.activation
    )[:, None, :, :]
    mu_cq: Float[Array, " c #q"] = self.mu[:, None]
    R, _ = math.polar_rv(F)
    residual = F - R @ activation_matrix
    return mu_cq * math.fro_norm_square(residual)

from_geometry classmethod ¤

from_geometry(
    geometry: Geometry,
    *,
    quadrature: Scheme | None = None,
    **kwargs,
) -> Self
Source code in src/liblaf/apple/jax/sim/energy/elastic/_elastic.py
22
23
24
25
26
27
28
29
@classmethod
def from_geometry(
    cls, geometry: Geometry, *, quadrature: Scheme | None = None, **kwargs
) -> Self:
    """Build the energy on a gradient-enabled region created from ``geometry``.

    Args:
        geometry: Mesh geometry to integrate over.
        quadrature: Optional quadrature scheme passed to ``Region.from_geometry``.
        **kwargs: Forwarded unchanged to ``cls.from_region``.
    """
    return cls.from_region(
        Region.from_geometry(geometry, grad=True, quadrature=quadrature), **kwargs
    )

from_pyvista classmethod ¤

from_pyvista(
    mesh: UnstructuredGrid,
    *,
    quadrature: Scheme | None = None,
    **kwargs,
) -> Self
Source code in src/liblaf/apple/jax/sim/energy/elastic/_elastic.py
31
32
33
34
35
36
@classmethod
def from_pyvista(
    cls, mesh: pv.UnstructuredGrid, *, quadrature: Scheme | None = None, **kwargs
) -> Self:
    """Build the energy from a PyVista unstructured grid.

    The mesh is converted with ``Geometry.from_pyvista`` and handed to
    ``cls.from_geometry`` together with ``quadrature`` and ``**kwargs``.
    """
    return cls.from_geometry(
        Geometry.from_pyvista(mesh), quadrature=quadrature, **kwargs
    )

from_region classmethod ¤

from_region(region: Region, **kwargs) -> Self
Source code in src/liblaf/apple/jax/sim/energy/elastic/_arap_active.py
23
24
25
26
27
28
29
30
@classmethod
def from_region(cls, region: Region, **kwargs) -> Self:
    """Construct the energy on ``region`` from its ``activation`` and ``mu`` cell data.

    Extra keyword arguments are forwarded to the constructor.
    """
    cell_data = region.cell_data
    activation = cell_data["activation"]
    mu = cell_data["mu"]
    return cls(region=region, activation=activation, mu=mu, **kwargs)

fun ¤

fun(u: Vector) -> Scalar
Source code in src/liblaf/apple/jax/sim/energy/elastic/_elastic.py
38
39
40
41
42
@override
def fun(self, u: Vector) -> Scalar:
    """Return the total elastic energy for the degrees of freedom ``u``.

    ``self.region`` turns ``u`` into per-quadrature-point deformation
    gradients. :meth:`energy_density` maps them to energy densities, which
    ``self.region.integrate`` integrates before everything is summed.
    """
    deformation: Float[Array, "c q J J"] = self.region.deformation_gradient(u)
    density: Float[Array, "c q"] = self.energy_density(deformation)
    integrated = self.region.integrate(density)
    return integrated.sum()

fun_and_jac ¤

fun_and_jac(u: Vector) -> tuple[Scalar, Updates]
Source code in src/liblaf/apple/jax/sim/energy/_energy.py
40
41
42
43
44
45
def fun_and_jac(self, u: Vector) -> tuple[Scalar, Updates]:
    """Evaluate the energy and its gradient at ``u`` in a single pass.

    Returns:
        ``(value, (grad, index))``. ``grad`` is the dense gradient from
        ``jax.value_and_grad``; ``index`` is ``arange`` over its leading axis.
    """
    value_and_grad_fn = jax.value_and_grad(self.fun)
    energy, grad = value_and_grad_fn(u)
    return energy, (grad, jnp.arange(grad.shape[0]))

hess_diag ¤

hess_diag(u: Vector) -> Updates
Source code in src/liblaf/apple/jax/sim/energy/_energy.py
23
24
25
26
def hess_diag(self, u: Vector) -> Updates:
    """Return the Hessian diagonal of ``self.fun`` at ``u`` as ``(data, index)``.

    ``data`` comes from the project helper ``math.hess_diag``; ``index``
    enumerates its leading axis.
    """
    diagonal = math.hess_diag(self.fun, u)
    return diagonal, jnp.arange(diagonal.shape[0])

hess_prod ¤

hess_prod(u: Vector, p: Vector) -> Updates
Source code in src/liblaf/apple/jax/sim/energy/_energy.py
28
29
30
31
32
def hess_prod(self, u: Vector, p: Vector) -> Updates:
    """Return the Hessian-vector product ``H(u) @ p`` as ``(data, index)``.

    Forward-over-reverse: ``jax.jvp`` pushes ``p`` through ``jax.grad(self.fun)``.
    """
    grad_fn = jax.grad(self.fun)
    _primal, hvp = jax.jvp(grad_fn, (u,), (p,))
    return hvp, jnp.arange(hvp.shape[0])

hess_quad ¤

hess_quad(u: Vector, p: Vector) -> Scalar
Source code in src/liblaf/apple/jax/sim/energy/_energy.py
34
35
36
37
38
def hess_quad(self, u: Vector, p: Vector) -> Scalar:
    """Return the quadratic form ``p^T H(u) p``.

    The product entries from :meth:`hess_prod` are dotted with ``p`` gathered
    at the returned indices.
    """
    hp, idx = self.hess_prod(u, p)
    gathered_p = p[idx]
    return jnp.vdot(gathered_p, hp)

jac ¤

jac(u: Vector) -> Updates
Source code in src/liblaf/apple/jax/sim/energy/_energy.py
18
19
20
21
def jac(self, u: Vector) -> Updates:
    """Return the gradient of ``self.fun`` at ``u`` as ``(data, index)`` updates.

    ``index`` is ``arange`` over the gradient's leading axis, so each entry is
    addressed exactly once.
    """
    grad_fn = jax.grad(self.fun)
    gradient = grad_fn(u)
    return gradient, jnp.arange(gradient.shape[0])

jac_and_hess_diag ¤

jac_and_hess_diag(u: Vector) -> tuple[Updates, Updates]
Source code in src/liblaf/apple/jax/sim/energy/_energy.py
47
48
49
50
def jac_and_hess_diag(self, u: Vector) -> tuple[Updates, Updates]:
    """Return ``(self.jac(u), self.hess_diag(u))``; the gradient is computed first."""
    gradient_updates: Updates = self.jac(u)
    return gradient_updates, self.hess_diag(u)

mixed_derivative_prod ¤

mixed_derivative_prod(
    u: Vector, p: Vector
) -> dict[str, Array]
Source code in src/liblaf/apple/jax/sim/energy/_energy.py
52
53
54
55
56
def mixed_derivative_prod(self, u: Vector, p: Vector) -> dict[str, Array]:
    """Collect mixed-derivative products for each name in ``requires_grad``.

    For each name ``n``, ``self.mixed_derivative_prod_<n>(u, p)`` is looked up
    and called. Results are keyed by ``n``, in ``requires_grad`` order. A
    missing method raises ``AttributeError``.
    """
    return {
        name: getattr(self, f"mixed_derivative_prod_{name}")(u, p)
        for name in self.requires_grad
    }

mixed_derivative_prod_activation ¤

mixed_derivative_prod_activation(
    u: Vector, p: Vector
) -> Vector
Source code in src/liblaf/apple/jax/sim/energy/elastic/_arap_active.py
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
def mixed_derivative_prod_activation(self, u: Vector, p: Vector) -> Vector:
    """Return ``p^T d(grad_u E)/d(activation)`` evaluated at ``self.activation``.

    The ``u``-gradient is expressed as a function of the activation. It is
    rebuilt via ``attrs.evolve``, made dense with ``jax.ops.segment_sum``, and
    pulled back through ``jax.vjp`` with cotangent ``p``.
    """

    def grad_wrt_u(activation: Float[Array, "c 6"]) -> Vector:
        # Rebuild the energy with the traced activation so JAX sees the dependency.
        evolved: Self = attrs.evolve(self, activation=activation)
        updates_data, updates_index = evolved.jac(u)
        # jac() yields (data, index) updates; scatter-add them into a dense vector.
        return jax.ops.segment_sum(
            updates_data, updates_index, num_segments=u.shape[0]
        )

    _, pullback = jax.vjp(grad_wrt_u, self.activation)
    (cotangent_activation,) = pullback(p)
    return cotangent_activation

Elastic ¤

Bases: Energy

Parameters:

  • id (str, default: <dynamic> ) –
  • requires_grad (Sequence[str], default: () ) –
  • region (Region) –

Methods:

Attributes:

id class-attribute instance-attribute ¤

id: str = field(
    default=Factory(_default_id, takes_self=True),
    kw_only=True,
)

region instance-attribute ¤

region: Region

requires_grad class-attribute instance-attribute ¤

requires_grad: Sequence[str] = field(
    default=(), kw_only=True
)

energy_density ¤

energy_density(
    F: Float[Array, "c q J J"],
) -> Float[Array, "c q"]
Source code in src/liblaf/apple/jax/sim/energy/elastic/_elastic.py
44
45
def energy_density(self, F: Float[Array, "c q J J"]) -> Float[Array, "c q"]:
    """Return the strain-energy density at every quadrature point.

    Abstract hook: ``fun`` evaluates it on the deformation gradients produced
    by ``self.region`` and integrates the result. Concrete elastic models must
    override it.

    Args:
        F: Deformation gradients, shape ``(c, q, J, J)``.

    Raises:
        NotImplementedError: Always, naming the class that did not override it.
    """
    msg = f"{type(self).__name__}.energy_density() is not implemented"
    raise NotImplementedError(msg)

from_geometry classmethod ¤

from_geometry(
    geometry: Geometry,
    *,
    quadrature: Scheme | None = None,
    **kwargs,
) -> Self
Source code in src/liblaf/apple/jax/sim/energy/elastic/_elastic.py
22
23
24
25
26
27
28
29
@classmethod
def from_geometry(
    cls, geometry: Geometry, *, quadrature: Scheme | None = None, **kwargs
) -> Self:
    """Build the energy on a gradient-enabled region created from ``geometry``.

    Args:
        geometry: Mesh geometry to integrate over.
        quadrature: Optional quadrature scheme passed to ``Region.from_geometry``.
        **kwargs: Forwarded unchanged to ``cls.from_region``.
    """
    return cls.from_region(
        Region.from_geometry(geometry, grad=True, quadrature=quadrature), **kwargs
    )

from_pyvista classmethod ¤

from_pyvista(
    mesh: UnstructuredGrid,
    *,
    quadrature: Scheme | None = None,
    **kwargs,
) -> Self
Source code in src/liblaf/apple/jax/sim/energy/elastic/_elastic.py
31
32
33
34
35
36
@classmethod
def from_pyvista(
    cls, mesh: pv.UnstructuredGrid, *, quadrature: Scheme | None = None, **kwargs
) -> Self:
    """Build the energy from a PyVista unstructured grid.

    The mesh is converted with ``Geometry.from_pyvista`` and handed to
    ``cls.from_geometry`` together with ``quadrature`` and ``**kwargs``.
    """
    return cls.from_geometry(
        Geometry.from_pyvista(mesh), quadrature=quadrature, **kwargs
    )

from_region classmethod ¤

from_region(region: Region, **kwargs) -> Self
Source code in src/liblaf/apple/jax/sim/energy/elastic/_elastic.py
18
19
20
@classmethod
def from_region(cls, region: Region, **kwargs) -> Self:
    """Construct the energy directly on ``region``.

    All other keyword arguments go to the constructor unchanged. Passing
    ``region`` again in ``kwargs`` raises ``TypeError``.
    """
    return cls(**kwargs, region=region)

fun ¤

fun(u: Vector) -> Scalar
Source code in src/liblaf/apple/jax/sim/energy/elastic/_elastic.py
38
39
40
41
42
@override
def fun(self, u: Vector) -> Scalar:
    """Return the total elastic energy for the degrees of freedom ``u``.

    ``self.region`` turns ``u`` into per-quadrature-point deformation
    gradients. :meth:`energy_density` maps them to energy densities, which
    ``self.region.integrate`` integrates before everything is summed.
    """
    deformation: Float[Array, "c q J J"] = self.region.deformation_gradient(u)
    density: Float[Array, "c q"] = self.energy_density(deformation)
    integrated = self.region.integrate(density)
    return integrated.sum()

fun_and_jac ¤

fun_and_jac(u: Vector) -> tuple[Scalar, Updates]
Source code in src/liblaf/apple/jax/sim/energy/_energy.py
40
41
42
43
44
45
def fun_and_jac(self, u: Vector) -> tuple[Scalar, Updates]:
    """Evaluate the energy and its gradient at ``u`` in a single pass.

    Returns:
        ``(value, (grad, index))``. ``grad`` is the dense gradient from
        ``jax.value_and_grad``; ``index`` is ``arange`` over its leading axis.
    """
    value_and_grad_fn = jax.value_and_grad(self.fun)
    energy, grad = value_and_grad_fn(u)
    return energy, (grad, jnp.arange(grad.shape[0]))

hess_diag ¤

hess_diag(u: Vector) -> Updates
Source code in src/liblaf/apple/jax/sim/energy/_energy.py
23
24
25
26
def hess_diag(self, u: Vector) -> Updates:
    """Return the Hessian diagonal of ``self.fun`` at ``u`` as ``(data, index)``.

    ``data`` comes from the project helper ``math.hess_diag``; ``index``
    enumerates its leading axis.
    """
    diagonal = math.hess_diag(self.fun, u)
    return diagonal, jnp.arange(diagonal.shape[0])

hess_prod ¤

hess_prod(u: Vector, p: Vector) -> Updates
Source code in src/liblaf/apple/jax/sim/energy/_energy.py
28
29
30
31
32
def hess_prod(self, u: Vector, p: Vector) -> Updates:
    """Return the Hessian-vector product ``H(u) @ p`` as ``(data, index)``.

    Forward-over-reverse: ``jax.jvp`` pushes ``p`` through ``jax.grad(self.fun)``.
    """
    grad_fn = jax.grad(self.fun)
    _primal, hvp = jax.jvp(grad_fn, (u,), (p,))
    return hvp, jnp.arange(hvp.shape[0])

hess_quad ¤

hess_quad(u: Vector, p: Vector) -> Scalar
Source code in src/liblaf/apple/jax/sim/energy/_energy.py
34
35
36
37
38
def hess_quad(self, u: Vector, p: Vector) -> Scalar:
    """Return the quadratic form ``p^T H(u) p``.

    The product entries from :meth:`hess_prod` are dotted with ``p`` gathered
    at the returned indices.
    """
    hp, idx = self.hess_prod(u, p)
    gathered_p = p[idx]
    return jnp.vdot(gathered_p, hp)

jac ¤

jac(u: Vector) -> Updates
Source code in src/liblaf/apple/jax/sim/energy/_energy.py
18
19
20
21
def jac(self, u: Vector) -> Updates:
    """Return the gradient of ``self.fun`` at ``u`` as ``(data, index)`` updates.

    ``index`` is ``arange`` over the gradient's leading axis, so each entry is
    addressed exactly once.
    """
    grad_fn = jax.grad(self.fun)
    gradient = grad_fn(u)
    return gradient, jnp.arange(gradient.shape[0])

jac_and_hess_diag ¤

jac_and_hess_diag(u: Vector) -> tuple[Updates, Updates]
Source code in src/liblaf/apple/jax/sim/energy/_energy.py
47
48
49
50
def jac_and_hess_diag(self, u: Vector) -> tuple[Updates, Updates]:
    """Return ``(self.jac(u), self.hess_diag(u))``; the gradient is computed first."""
    gradient_updates: Updates = self.jac(u)
    return gradient_updates, self.hess_diag(u)

mixed_derivative_prod ¤

mixed_derivative_prod(
    u: Vector, p: Vector
) -> dict[str, Array]
Source code in src/liblaf/apple/jax/sim/energy/_energy.py
52
53
54
55
56
def mixed_derivative_prod(self, u: Vector, p: Vector) -> dict[str, Array]:
    """Collect mixed-derivative products for each name in ``requires_grad``.

    For each name ``n``, ``self.mixed_derivative_prod_<n>(u, p)`` is looked up
    and called. Results are keyed by ``n``, in ``requires_grad`` order. A
    missing method raises ``AttributeError``.
    """
    return {
        name: getattr(self, f"mixed_derivative_prod_{name}")(u, p)
        for name in self.requires_grad
    }

Energy ¤

Bases: IdMixin

Parameters:

  • id (str, default: <dynamic> ) –
  • requires_grad (Sequence[str], default: () ) –

Methods:

Attributes:

id class-attribute instance-attribute ¤

id: str = field(
    default=Factory(_default_id, takes_self=True),
    kw_only=True,
)

requires_grad class-attribute instance-attribute ¤

requires_grad: Sequence[str] = field(
    default=(), kw_only=True
)

fun ¤

fun(u: Vector) -> Scalar
Source code in src/liblaf/apple/jax/sim/energy/_energy.py
15
16
def fun(self, u: Vector) -> Scalar:
    """Return the scalar energy at ``u``.

    Abstract hook that subclasses must override. ``jac``, ``fun_and_jac``,
    ``hess_diag`` and ``hess_prod`` all differentiate this function.

    Raises:
        NotImplementedError: Always, naming the class that did not override it.
    """
    msg = f"{type(self).__name__}.fun() is not implemented"
    raise NotImplementedError(msg)

fun_and_jac ¤

fun_and_jac(u: Vector) -> tuple[Scalar, Updates]
Source code in src/liblaf/apple/jax/sim/energy/_energy.py
40
41
42
43
44
45
def fun_and_jac(self, u: Vector) -> tuple[Scalar, Updates]:
    """Evaluate the energy and its gradient at ``u`` in a single pass.

    Returns:
        ``(value, (grad, index))``. ``grad`` is the dense gradient from
        ``jax.value_and_grad``; ``index`` is ``arange`` over its leading axis.
    """
    value_and_grad_fn = jax.value_and_grad(self.fun)
    energy, grad = value_and_grad_fn(u)
    return energy, (grad, jnp.arange(grad.shape[0]))

hess_diag ¤

hess_diag(u: Vector) -> Updates
Source code in src/liblaf/apple/jax/sim/energy/_energy.py
23
24
25
26
def hess_diag(self, u: Vector) -> Updates:
    """Return the Hessian diagonal of ``self.fun`` at ``u`` as ``(data, index)``.

    ``data`` comes from the project helper ``math.hess_diag``; ``index``
    enumerates its leading axis.
    """
    diagonal = math.hess_diag(self.fun, u)
    return diagonal, jnp.arange(diagonal.shape[0])

hess_prod ¤

hess_prod(u: Vector, p: Vector) -> Updates
Source code in src/liblaf/apple/jax/sim/energy/_energy.py
28
29
30
31
32
def hess_prod(self, u: Vector, p: Vector) -> Updates:
    """Return the Hessian-vector product ``H(u) @ p`` as ``(data, index)``.

    Forward-over-reverse: ``jax.jvp`` pushes ``p`` through ``jax.grad(self.fun)``.
    """
    grad_fn = jax.grad(self.fun)
    _primal, hvp = jax.jvp(grad_fn, (u,), (p,))
    return hvp, jnp.arange(hvp.shape[0])

hess_quad ¤

hess_quad(u: Vector, p: Vector) -> Scalar
Source code in src/liblaf/apple/jax/sim/energy/_energy.py
34
35
36
37
38
def hess_quad(self, u: Vector, p: Vector) -> Scalar:
    """Return the quadratic form ``p^T H(u) p``.

    The product entries from :meth:`hess_prod` are dotted with ``p`` gathered
    at the returned indices.
    """
    hp, idx = self.hess_prod(u, p)
    gathered_p = p[idx]
    return jnp.vdot(gathered_p, hp)

jac ¤

jac(u: Vector) -> Updates
Source code in src/liblaf/apple/jax/sim/energy/_energy.py
18
19
20
21
def jac(self, u: Vector) -> Updates:
    """Return the gradient of ``self.fun`` at ``u`` as ``(data, index)`` updates.

    ``index`` is ``arange`` over the gradient's leading axis, so each entry is
    addressed exactly once.
    """
    grad_fn = jax.grad(self.fun)
    gradient = grad_fn(u)
    return gradient, jnp.arange(gradient.shape[0])

jac_and_hess_diag ¤

jac_and_hess_diag(u: Vector) -> tuple[Updates, Updates]
Source code in src/liblaf/apple/jax/sim/energy/_energy.py
47
48
49
50
def jac_and_hess_diag(self, u: Vector) -> tuple[Updates, Updates]:
    """Return ``(self.jac(u), self.hess_diag(u))``; the gradient is computed first."""
    gradient_updates: Updates = self.jac(u)
    return gradient_updates, self.hess_diag(u)

mixed_derivative_prod ¤

mixed_derivative_prod(
    u: Vector, p: Vector
) -> dict[str, Array]
Source code in src/liblaf/apple/jax/sim/energy/_energy.py
52
53
54
55
56
def mixed_derivative_prod(self, u: Vector, p: Vector) -> dict[str, Array]:
    """Collect mixed-derivative products for each name in ``requires_grad``.

    For each name ``n``, ``self.mixed_derivative_prod_<n>(u, p)`` is looked up
    and called. Results are keyed by ``n``, in ``requires_grad`` order. A
    missing method raises ``AttributeError``.
    """
    return {
        name: getattr(self, f"mixed_derivative_prod_{name}")(u, p)
        for name in self.requires_grad
    }

Koiter ¤

Bases: Energy

Parameters:

  • id (str, default: <dynamic> ) –
  • requires_grad (Sequence[str], default: () ) –
  • alpha (Float[Array, c]) –

    Lamé’s first parameter.

  • beta (Float[Array, c]) –

    Lamé’s second parameter.

  • det_Iu (Float[Array, c]) –

    Determinant of Iu, the first fundamental form of the undeformed midsurface.

  • h (Float[Array, c]) –

    Thickness.

  • Iu_inv (Float[Array, 'c 2 2']) –

    Inverse of the midsurface first fundamental form.

  • pre_strain (Float[Array, c]) –
  • geometry (Geometry) –

Methods:

Attributes:

  • Iu_inv (Float[Array, 'c 2 2']) –

    Inverse of the midsurface first fundamental form.

  • alpha (Float[Array, ' c']) –

    Lamé’s first parameter.

  • beta (Float[Array, ' c']) –

    Lamé’s second parameter.

  • det_Iu (Float[Array, ' c']) –

    Determinant of Iu, the first fundamental form of the undeformed midsurface.

  • geometry (Geometry) –
  • h (Float[Array, ' c']) –

    Thickness.

  • id (str) –
  • pre_strain (Float[Array, ' c']) –
  • requires_grad (Sequence[str]) –

Iu_inv class-attribute instance-attribute ¤

Iu_inv: Float[Array, 'c 2 2'] = array()

Inverse of the midsurface first fundamental form.

alpha class-attribute instance-attribute ¤

alpha: Float[Array, ' c'] = array()

Lamé’s first parameter.

beta class-attribute instance-attribute ¤

beta: Float[Array, ' c'] = array()

Lamé’s second parameter.

det_Iu class-attribute instance-attribute ¤

det_Iu: Float[Array, ' c'] = array()

Determinant of Iu, the first fundamental form of the undeformed midsurface.

geometry class-attribute instance-attribute ¤

geometry: Geometry = field()

h class-attribute instance-attribute ¤

h: Float[Array, ' c'] = array()

Thickness.

id class-attribute instance-attribute ¤

id: str = field(
    default=Factory(_default_id, takes_self=True),
    kw_only=True,
)

pre_strain class-attribute instance-attribute ¤

pre_strain: Float[Array, ' c'] = array()

requires_grad class-attribute instance-attribute ¤

requires_grad: Sequence[str] = field(
    default=(), kw_only=True
)

from_geometry classmethod ¤

from_geometry(geometry: Geometry) -> Self
Source code in src/liblaf/apple/jax/sim/energy/_koiter.py
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
@classmethod
def from_geometry(cls, geometry: Geometry, **kwargs) -> Self:
    """Build a Koiter energy from a surface ``geometry``.

    Reads the per-cell ``"alpha"``, ``"beta"``, ``"h"`` and ``"pre-strain"``
    arrays from ``geometry.cell_data`` as float arrays. Evaluates the reference
    first fundamental form ``Iu`` on the undeformed cell vertices and stores
    only ``det(Iu)`` and ``inv(Iu)``.

    Args:
        geometry: Surface geometry carrying the cell data listed above.
        **kwargs: Extra constructor arguments (for example ``id`` or
            ``requires_grad``), forwarded the same way the elastic
            ``from_geometry``/``from_region`` factories forward ``**kwargs``.

    Returns:
        The constructed energy.
    """
    cell_data = geometry.cell_data
    alpha: Float[Array, " c"] = math.asarray(cell_data["alpha"], dtype=float)
    beta: Float[Array, " c"] = math.asarray(cell_data["beta"], dtype=float)
    h: Float[Array, " c"] = math.asarray(cell_data["h"], dtype=float)
    # Reference (undeformed) first fundamental form: one 2x2 matrix per cell.
    Iu: Float[Array, "c 2 2"] = _first_fundamental_form(
        geometry.points[geometry.cells_global]
    )
    pre_strain: Float[Array, " c"] = math.asarray(
        cell_data["pre-strain"], dtype=float
    )
    return cls(
        alpha=alpha,
        beta=beta,
        det_Iu=jnp.linalg.det(Iu),
        Iu_inv=jnp.linalg.inv(Iu),
        h=h,
        pre_strain=pre_strain,
        geometry=geometry,
        **kwargs,
    )

fun ¤

fun(u: Vector) -> Scalar
Source code in src/liblaf/apple/jax/sim/energy/_koiter.py
58
59
60
61
62
63
64
65
66
67
68
69
70
def fun(self, u: Vector) -> Scalar:
    """Return the Koiter stretching energy for the vertex displacements ``u``.

    The first fundamental form ``I`` is evaluated on the displaced cell
    vertices (``u + points``). ``M = Iu^{-1} I - pre_strain * Id`` is measured
    by ``self._norm_SV``. Per cell, the result is scaled by ``0.5 * 0.25 * h``
    and by ``sqrt(det(Iu))``, then all cells are summed.
    """
    cells = self.geometry.cells_global
    deformed_vertices = u[cells] + self.geometry.points[cells]
    I: Float[Array, "c 2 2"] = _first_fundamental_form(deformed_vertices)  # noqa: E741
    identity = jnp.eye(2)[jnp.newaxis, ...]
    M: Float[Array, "c 2 2"] = (
        jnp.matmul(self.Iu_inv, I)
        - self.pre_strain[:, jnp.newaxis, jnp.newaxis] * identity
    )
    # NOTE(review): originally annotated as a scalar (Float[Array, ""]), but
    # presumably one value per cell given M is (c, 2, 2) — confirm _norm_SV.
    Ws = self._norm_SV(M)
    per_cell: Float[Array, " c"] = 0.5 * (0.25 * self.h * Ws) * jnp.sqrt(self.det_Iu)
    return jnp.sum(per_cell)

fun_and_jac ¤

fun_and_jac(u: Vector) -> tuple[Scalar, Updates]
Source code in src/liblaf/apple/jax/sim/energy/_energy.py
40
41
42
43
44
45
def fun_and_jac(self, u: Vector) -> tuple[Scalar, Updates]:
    """Evaluate the energy and its gradient at ``u`` in a single pass.

    Returns:
        ``(value, (grad, index))``. ``grad`` is the dense gradient from
        ``jax.value_and_grad``; ``index`` is ``arange`` over its leading axis.
    """
    value_and_grad_fn = jax.value_and_grad(self.fun)
    energy, grad = value_and_grad_fn(u)
    return energy, (grad, jnp.arange(grad.shape[0]))

hess_diag ¤

hess_diag(u: Vector) -> Updates
Source code in src/liblaf/apple/jax/sim/energy/_energy.py
23
24
25
26
def hess_diag(self, u: Vector) -> Updates:
    """Return the Hessian diagonal of ``self.fun`` at ``u`` as ``(data, index)``.

    ``data`` comes from the project helper ``math.hess_diag``; ``index``
    enumerates its leading axis.
    """
    diagonal = math.hess_diag(self.fun, u)
    return diagonal, jnp.arange(diagonal.shape[0])

hess_prod ¤

hess_prod(u: Vector, p: Vector) -> Updates
Source code in src/liblaf/apple/jax/sim/energy/_energy.py
28
29
30
31
32
def hess_prod(self, u: Vector, p: Vector) -> Updates:
    """Return the Hessian-vector product ``H(u) @ p`` as ``(data, index)``.

    Forward-over-reverse: ``jax.jvp`` pushes ``p`` through ``jax.grad(self.fun)``.
    """
    grad_fn = jax.grad(self.fun)
    _primal, hvp = jax.jvp(grad_fn, (u,), (p,))
    return hvp, jnp.arange(hvp.shape[0])

hess_quad ¤

hess_quad(u: Vector, p: Vector) -> Scalar
Source code in src/liblaf/apple/jax/sim/energy/_energy.py
34
35
36
37
38
def hess_quad(self, u: Vector, p: Vector) -> Scalar:
    """Return the quadratic form ``p^T H(u) p``.

    The product entries from :meth:`hess_prod` are dotted with ``p`` gathered
    at the returned indices.
    """
    hp, idx = self.hess_prod(u, p)
    gathered_p = p[idx]
    return jnp.vdot(gathered_p, hp)

jac ¤

jac(u: Vector) -> Updates
Source code in src/liblaf/apple/jax/sim/energy/_energy.py
18
19
20
21
def jac(self, u: Vector) -> Updates:
    """Return the gradient of ``self.fun`` at ``u`` as ``(data, index)`` updates.

    ``index`` is ``arange`` over the gradient's leading axis, so each entry is
    addressed exactly once.
    """
    grad_fn = jax.grad(self.fun)
    gradient = grad_fn(u)
    return gradient, jnp.arange(gradient.shape[0])

jac_and_hess_diag ¤

jac_and_hess_diag(u: Vector) -> tuple[Updates, Updates]
Source code in src/liblaf/apple/jax/sim/energy/_energy.py
47
48
49
50
def jac_and_hess_diag(self, u: Vector) -> tuple[Updates, Updates]:
    """Return ``(self.jac(u), self.hess_diag(u))``; the gradient is computed first."""
    gradient_updates: Updates = self.jac(u)
    return gradient_updates, self.hess_diag(u)

mixed_derivative_prod ¤

mixed_derivative_prod(
    u: Vector, p: Vector
) -> dict[str, Array]
Source code in src/liblaf/apple/jax/sim/energy/_energy.py
52
53
54
55
56
def mixed_derivative_prod(self, u: Vector, p: Vector) -> dict[str, Array]:
    """Collect mixed-derivative products for each name in ``requires_grad``.

    For each name ``n``, ``self.mixed_derivative_prod_<n>(u, p)`` is looked up
    and called. Results are keyed by ``n``, in ``requires_grad`` order. A
    missing method raises ``AttributeError``.
    """
    return {
        name: getattr(self, f"mixed_derivative_prod_{name}")(u, p)
        for name in self.requires_grad
    }

PhaceActive ¤

Bases: Elastic

Parameters:

  • id (str, default: <dynamic> ) –
  • requires_grad (Sequence[str], default: () ) –
  • region (Region) –
  • activation (Float[Array, 'c J J']) –
  • lambda_ (Float[Array, c]) –
  • mu (Float[Array, c]) –

Methods:

Attributes:

activation class-attribute instance-attribute ¤

activation: Float[Array, 'c J J'] = array()

id class-attribute instance-attribute ¤

id: str = field(
    default=Factory(_default_id, takes_self=True),
    kw_only=True,
)

lambda_ class-attribute instance-attribute ¤

lambda_: Float[Array, ' c'] = array()

mu class-attribute instance-attribute ¤

mu: Float[Array, ' c'] = array()

region instance-attribute ¤

region: Region

requires_grad class-attribute instance-attribute ¤

requires_grad: Sequence[str] = field(
    default=(), kw_only=True
)

energy_density ¤

energy_density(
    F: Float[Array, "c q J J"],
) -> Float[Array, "c q"]
Source code in src/liblaf/apple/jax/sim/energy/elastic/_phace_active.py
35
36
37
38
39
40
41
42
43
44
45
46
47
48
@override
def energy_density(self, F: Float[Array, "c q J J"]) -> Float[Array, "c q"]:
    """Return the active Phace energy density.

    ``Psi = mu * ||F - R A||_F^2 + lambda_ * (det(F) - 1)^2``, where ``A`` is
    ``utils.make_activation(self.activation)`` broadcast over quadrature points
    and ``R`` is the first output of ``math.polar_rv(F)``.

    NOTE(review): ``PhaceStatic.energy_density`` weights its ARAP term by
    ``2.0`` while this one does not, so at ``A = I`` the two models differ —
    confirm intended.
    """
    activation_matrix: Float[Array, " c #q J J"] = utils.make_activation(
        self.activation
    )[:, None, :, :]
    lam_cq: Float[Array, " c #q"] = self.lambda_[:, None]
    mu_cq: Float[Array, " c #q"] = self.mu[:, None]
    R, _ = math.polar_rv(F)
    det_F: Float[Array, "c q"] = jnp.linalg.det(F)
    arap_term = mu_cq * math.fro_norm_square(F - R @ activation_matrix)
    volume_term = lam_cq * (det_F - 1.0) ** 2
    return arap_term + volume_term

from_geometry classmethod ¤

from_geometry(
    geometry: Geometry,
    *,
    quadrature: Scheme | None = None,
    **kwargs,
) -> Self
Source code in src/liblaf/apple/jax/sim/energy/elastic/_elastic.py
22
23
24
25
26
27
28
29
@classmethod
def from_geometry(
    cls, geometry: Geometry, *, quadrature: Scheme | None = None, **kwargs
) -> Self:
    """Build the energy on a gradient-enabled region created from ``geometry``.

    Args:
        geometry: Mesh geometry to integrate over.
        quadrature: Optional quadrature scheme passed to ``Region.from_geometry``.
        **kwargs: Forwarded unchanged to ``cls.from_region``.
    """
    return cls.from_region(
        Region.from_geometry(geometry, grad=True, quadrature=quadrature), **kwargs
    )

from_pyvista classmethod ¤

from_pyvista(
    mesh: UnstructuredGrid,
    *,
    quadrature: Scheme | None = None,
    **kwargs,
) -> Self
Source code in src/liblaf/apple/jax/sim/energy/elastic/_elastic.py
31
32
33
34
35
36
@classmethod
def from_pyvista(
    cls, mesh: pv.UnstructuredGrid, *, quadrature: Scheme | None = None, **kwargs
) -> Self:
    """Build the energy from a PyVista unstructured grid.

    The mesh is converted with ``Geometry.from_pyvista`` and handed to
    ``cls.from_geometry`` together with ``quadrature`` and ``**kwargs``.
    """
    return cls.from_geometry(
        Geometry.from_pyvista(mesh), quadrature=quadrature, **kwargs
    )

from_region classmethod ¤

from_region(region: Region, **kwargs) -> Self
Source code in src/liblaf/apple/jax/sim/energy/elastic/_phace_active.py
24
25
26
27
28
29
30
31
32
33
@override
@classmethod
def from_region(cls, region: Region, **kwargs) -> Self:
    """Construct the energy on ``region`` from its cell data.

    Reads the ``"activation"``, ``"lambda"`` and ``"mu"`` arrays (in that
    order) and forwards any extra keyword arguments to the constructor.
    """
    cell_data = region.cell_data
    activation = cell_data["activation"]
    lambda_ = cell_data["lambda"]
    mu = cell_data["mu"]
    return cls(region=region, activation=activation, lambda_=lambda_, mu=mu, **kwargs)

fun ¤

fun(u: Vector) -> Scalar
Source code in src/liblaf/apple/jax/sim/energy/elastic/_elastic.py
38
39
40
41
42
@override
def fun(self, u: Vector) -> Scalar:
    """Return the total elastic energy for the degrees of freedom ``u``.

    ``self.region`` turns ``u`` into per-quadrature-point deformation
    gradients. :meth:`energy_density` maps them to energy densities, which
    ``self.region.integrate`` integrates before everything is summed.
    """
    deformation: Float[Array, "c q J J"] = self.region.deformation_gradient(u)
    density: Float[Array, "c q"] = self.energy_density(deformation)
    integrated = self.region.integrate(density)
    return integrated.sum()

fun_and_jac ¤

fun_and_jac(u: Vector) -> tuple[Scalar, Updates]
Source code in src/liblaf/apple/jax/sim/energy/_energy.py
40
41
42
43
44
45
def fun_and_jac(self, u: Vector) -> tuple[Scalar, Updates]:
    """Evaluate the energy and its gradient at ``u`` in a single pass.

    Returns:
        ``(value, (grad, index))``. ``grad`` is the dense gradient from
        ``jax.value_and_grad``; ``index`` is ``arange`` over its leading axis.
    """
    value_and_grad_fn = jax.value_and_grad(self.fun)
    energy, grad = value_and_grad_fn(u)
    return energy, (grad, jnp.arange(grad.shape[0]))

hess_diag ¤

hess_diag(u: Vector) -> Updates
Source code in src/liblaf/apple/jax/sim/energy/_energy.py
23
24
25
26
def hess_diag(self, u: Vector) -> Updates:
    """Return the Hessian diagonal of ``self.fun`` at ``u`` as ``(data, index)``.

    ``data`` comes from the project helper ``math.hess_diag``; ``index``
    enumerates its leading axis.
    """
    diagonal = math.hess_diag(self.fun, u)
    return diagonal, jnp.arange(diagonal.shape[0])

hess_prod ¤

hess_prod(u: Vector, p: Vector) -> Updates
Source code in src/liblaf/apple/jax/sim/energy/_energy.py
28
29
30
31
32
def hess_prod(self, u: Vector, p: Vector) -> Updates:
    """Return the Hessian-vector product ``H(u) @ p`` as ``(data, index)``.

    Forward-over-reverse: ``jax.jvp`` pushes ``p`` through ``jax.grad(self.fun)``.
    """
    grad_fn = jax.grad(self.fun)
    _primal, hvp = jax.jvp(grad_fn, (u,), (p,))
    return hvp, jnp.arange(hvp.shape[0])

hess_quad ¤

hess_quad(u: Vector, p: Vector) -> Scalar
Source code in src/liblaf/apple/jax/sim/energy/_energy.py
34
35
36
37
38
def hess_quad(self, u: Vector, p: Vector) -> Scalar:
    """Return the quadratic form ``p^T H(u) p``.

    The product entries from :meth:`hess_prod` are dotted with ``p`` gathered
    at the returned indices.
    """
    hp, idx = self.hess_prod(u, p)
    gathered_p = p[idx]
    return jnp.vdot(gathered_p, hp)

jac ¤

jac(u: Vector) -> Updates
Source code in src/liblaf/apple/jax/sim/energy/_energy.py
18
19
20
21
def jac(self, u: Vector) -> Updates:
    """Return the gradient of ``self.fun`` at ``u`` as ``(data, index)`` updates.

    ``index`` is ``arange`` over the gradient's leading axis, so each entry is
    addressed exactly once.
    """
    grad_fn = jax.grad(self.fun)
    gradient = grad_fn(u)
    return gradient, jnp.arange(gradient.shape[0])

jac_and_hess_diag ¤

jac_and_hess_diag(u: Vector) -> tuple[Updates, Updates]
Source code in src/liblaf/apple/jax/sim/energy/_energy.py
47
48
49
50
def jac_and_hess_diag(self, u: Vector) -> tuple[Updates, Updates]:
    """Return ``(self.jac(u), self.hess_diag(u))``; the gradient is computed first."""
    gradient_updates: Updates = self.jac(u)
    return gradient_updates, self.hess_diag(u)

mixed_derivative_prod ¤

mixed_derivative_prod(
    u: Vector, p: Vector
) -> dict[str, Array]
Source code in src/liblaf/apple/jax/sim/energy/_energy.py
52
53
54
55
56
def mixed_derivative_prod(self, u: Vector, p: Vector) -> dict[str, Array]:
    """Collect mixed-derivative products for each name in ``requires_grad``.

    For each name ``n``, ``self.mixed_derivative_prod_<n>(u, p)`` is looked up
    and called. Results are keyed by ``n``, in ``requires_grad`` order. A
    missing method raises ``AttributeError``.
    """
    return {
        name: getattr(self, f"mixed_derivative_prod_{name}")(u, p)
        for name in self.requires_grad
    }

mixed_derivative_prod_activation ¤

mixed_derivative_prod_activation(
    u: Vector, p: Vector
) -> Vector
Source code in src/liblaf/apple/jax/sim/energy/elastic/_phace_active.py
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
def mixed_derivative_prod_activation(self, u: Vector, p: Vector) -> Vector:
    """Return ``p^T d(grad_u E)/d(activation)`` evaluated at ``self.activation``.

    The ``u``-gradient is expressed as a function of the activation. It is
    rebuilt via ``attrs.evolve``, made dense with ``jax.ops.segment_sum``, and
    pulled back through ``jax.vjp`` with cotangent ``p``.

    NOTE(review): the ``"c 6"`` annotations below look copied from
    ARAPActive; this class declares ``activation`` as ``'c J J'`` — confirm
    which shape is correct.
    """

    def grad_wrt_u(activation: Float[Array, "c 6"]) -> Vector:
        # Rebuild the energy with the traced activation so JAX sees the dependency.
        evolved: Self = attrs.evolve(self, activation=activation)
        updates_data, updates_index = evolved.jac(u)
        # jac() yields (data, index) updates; scatter-add them into a dense vector.
        return jax.ops.segment_sum(
            updates_data, updates_index, num_segments=u.shape[0]
        )

    _, pullback = jax.vjp(grad_wrt_u, self.activation)
    (cotangent_activation,) = pullback(p)
    return cotangent_activation

PhaceStatic ¤

Bases: Elastic

Parameters:

  • id (str, default: <dynamic> ) –
  • requires_grad (Sequence[str], default: () ) –
  • region (Region) –
  • mu (Float[Array, c]) –
  • lambda_ (Float[Array, c]) –

Methods:

Attributes:

id class-attribute instance-attribute ¤

id: str = field(
    default=Factory(_default_id, takes_self=True),
    kw_only=True,
)

lambda_ class-attribute instance-attribute ¤

lambda_: Float[Array, ' c'] = array()

mu class-attribute instance-attribute ¤

mu: Float[Array, ' c'] = array()

region instance-attribute ¤

region: Region

requires_grad class-attribute instance-attribute ¤

requires_grad: Sequence[str] = field(
    default=(), kw_only=True
)

energy_density ¤

energy_density(
    F: Float[Array, "c q J J"],
) -> Float[Array, "c q"]
Source code in src/liblaf/apple/jax/sim/energy/elastic/_phace_static.py
27
28
29
30
31
32
33
34
35
36
37
@override
def energy_density(self, F: Float[Array, "c q J J"]) -> Float[Array, "c q"]:
    """Return the static Phace energy density.

    ``Psi = 2 * mu * ||F - R||_F^2 + lambda_ * (det(F) - 1)^2``, where ``R`` is
    the first output of ``math.polar_rv(F)``. ``mu`` and ``lambda_`` are
    per-cell values broadcast over quadrature points.

    NOTE(review): ``PhaceActive.energy_density`` has no ``2.0`` on its ARAP
    term, so at identity activation the two models differ — confirm intended.
    """
    lam_cq: Float[Array, " c #q"] = self.lambda_[:, None]
    mu_cq: Float[Array, " c #q"] = self.mu[:, None]
    R, _ = math.polar_rv(F)
    det_F: Float[Array, "c q"] = jnp.linalg.det(F)
    arap_term = mu_cq * math.fro_norm_square(F - R)
    volume_term = lam_cq * (det_F - 1.0) ** 2
    return 2.0 * arap_term + volume_term

from_geometry classmethod ¤

from_geometry(
    geometry: Geometry,
    *,
    quadrature: Scheme | None = None,
    **kwargs,
) -> Self
Source code in src/liblaf/apple/jax/sim/energy/elastic/_elastic.py
22
23
24
25
26
27
28
29
@classmethod
def from_geometry(
    cls, geometry: Geometry, *, quadrature: Scheme | None = None, **kwargs
) -> Self:
    """Build the energy from ``geometry``.

    A ``Region`` is created with ``Region.from_geometry(geometry, grad=True,
    quadrature=quadrature)``. ``quadrature=None`` is passed through
    unchanged. The region and all remaining keyword arguments are then
    handed to ``cls.from_region``.
    """
    return cls.from_region(
        Region.from_geometry(geometry, grad=True, quadrature=quadrature),
        **kwargs,
    )

from_pyvista classmethod ¤

from_pyvista(
    mesh: UnstructuredGrid,
    *,
    quadrature: Scheme | None = None,
    **kwargs,
) -> Self
Source code in src/liblaf/apple/jax/sim/energy/elastic/_elastic.py
31
32
33
34
35
36
@classmethod
def from_pyvista(
    cls, mesh: pv.UnstructuredGrid, *, quadrature: Scheme | None = None, **kwargs
) -> Self:
    """Build the energy from a PyVista unstructured grid.

    The mesh is wrapped with ``Geometry.from_pyvista``. The result is
    forwarded to ``cls.from_geometry`` together with ``quadrature`` and any
    extra keyword arguments.
    """
    return cls.from_geometry(
        Geometry.from_pyvista(mesh), quadrature=quadrature, **kwargs
    )

from_region classmethod ¤

from_region(region: Region, **kwargs) -> Self
Source code in src/liblaf/apple/jax/sim/energy/elastic/_phace_static.py
17
18
19
20
21
22
23
24
25
@override
@classmethod
def from_region(cls, region: Region, **kwargs) -> Self:
    """Construct the energy on ``region`` with per-cell material coefficients.

    ``lambda_`` is read from ``region.cell_data["lambda"]`` and ``mu`` from
    ``region.cell_data["mu"]``. Any extra keyword arguments (e.g. ``id`` or
    ``requires_grad``) are forwarded to the constructor. Passing ``region``,
    ``mu`` or ``lambda_`` in ``kwargs`` raises ``TypeError`` because the
    keyword would be given twice.
    """
    cell_data = region.cell_data
    lambda_cells = cell_data["lambda"]
    mu_cells = cell_data["mu"]
    return cls(region=region, lambda_=lambda_cells, mu=mu_cells, **kwargs)

fun ¤

fun(u: Vector) -> Scalar
Source code in src/liblaf/apple/jax/sim/energy/elastic/_elastic.py
38
39
40
41
42
@override
def fun(self, u: Vector) -> Scalar:
    """Return the total elastic energy of the displacement ``u``.

    The deformation gradient from ``self.region`` is turned into a
    per-quadrature-point density by ``energy_density``. That density is
    integrated with ``self.region.integrate`` and the result is summed to a
    scalar.
    """
    deformation: Float[Array, "c q J J"] = self.region.deformation_gradient(u)
    return self.region.integrate(self.energy_density(deformation)).sum()

fun_and_jac ¤

fun_and_jac(u: Vector) -> tuple[Scalar, Updates]
Source code in src/liblaf/apple/jax/sim/energy/_energy.py
40
41
42
43
44
45
def fun_and_jac(self, u: Vector) -> tuple[Scalar, Updates]:
    """Evaluate ``self.fun`` and its gradient at ``u`` in one pass.

    Returns:
        ``(value, (grad, index))``. ``grad`` is the dense gradient of
        ``self.fun`` at ``u``. ``index`` is ``arange(len(grad))``, an
        identity scatter map.
    """
    energy_and_grad = jax.value_and_grad(self.fun)
    value, grad = energy_and_grad(u)
    identity_index: UpdatesIndex = jnp.arange(grad.shape[0])
    return value, (grad, identity_index)

hess_diag ¤

hess_diag(u: Vector) -> Updates
Source code in src/liblaf/apple/jax/sim/energy/_energy.py
23
24
25
26
def hess_diag(self, u: Vector) -> Updates:
    """Return the Hessian diagonal of ``self.fun`` at ``u`` as ``(data, index)``.

    ``data`` is produced by the project helper ``math.hess_diag(self.fun, u)``.
    ``index`` is the identity map ``arange(len(data))``.
    """
    diagonal: Vector = math.hess_diag(self.fun, u)
    return diagonal, jnp.arange(diagonal.shape[0])

hess_prod ¤

hess_prod(u: Vector, p: Vector) -> Updates
Source code in src/liblaf/apple/jax/sim/energy/_energy.py
28
29
30
31
32
def hess_prod(self, u: Vector, p: Vector) -> Updates:
    """Return the Hessian-vector product of ``self.fun`` at ``u`` with ``p``.

    The product is formed forward-over-reverse: ``jax.jvp`` pushes the
    tangent ``p`` through ``jax.grad(self.fun)`` at ``u``. The result is
    returned as ``(data, index)``, where ``index`` is the identity map
    ``arange(len(data))``.
    """
    gradient_fn = jax.grad(self.fun)
    _, hvp = jax.jvp(gradient_fn, (u,), (p,))
    return hvp, jnp.arange(hvp.shape[0])

hess_quad ¤

hess_quad(u: Vector, p: Vector) -> Scalar
Source code in src/liblaf/apple/jax/sim/energy/_energy.py
34
35
36
37
38
def hess_quad(self, u: Vector, p: Vector) -> Scalar:
    """Return the quadratic form ``p^T H(u) p``.

    ``self.hess_prod`` supplies the product as ``(data, index)`` updates.
    Pairing each update with ``p[index]`` gives the same result as first
    accumulating the updates into a dense vector, even when ``index``
    contains repeated entries.
    """
    hvp_data, hvp_index = self.hess_prod(u, p)
    return jnp.vdot(p[hvp_index], hvp_data)

jac ¤

jac(u: Vector) -> Updates
Source code in src/liblaf/apple/jax/sim/energy/_energy.py
18
19
20
21
def jac(self, u: Vector) -> Updates:
    """Return the gradient of ``self.fun`` at ``u`` as ``(data, index)``.

    ``data`` is the dense gradient from ``jax.grad``. ``index`` is the
    identity map ``arange(len(data))``.
    """
    gradient: Vector = jax.grad(self.fun)(u)
    return gradient, jnp.arange(gradient.shape[0])

jac_and_hess_diag ¤

jac_and_hess_diag(u: Vector) -> tuple[Updates, Updates]
Source code in src/liblaf/apple/jax/sim/energy/_energy.py
47
48
49
50
def jac_and_hess_diag(self, u: Vector) -> tuple[Updates, Updates]:
    """Return ``(self.jac(u), self.hess_diag(u))``.

    ``jac`` is evaluated before ``hess_diag``.
    """
    return self.jac(u), self.hess_diag(u)

mixed_derivative_prod ¤

mixed_derivative_prod(
    u: Vector, p: Vector
) -> dict[str, Array]
Source code in src/liblaf/apple/jax/sim/energy/_energy.py
52
53
54
55
56
def mixed_derivative_prod(self, u: Vector, p: Vector) -> dict[str, Array]:
    """Collect mixed-derivative products for every name in ``requires_grad``.

    For each ``name`` in ``self.requires_grad`` (in order), the method
    ``mixed_derivative_prod_<name>(u, p)`` is called and its result is
    stored under ``name``. An empty ``requires_grad`` yields an empty dict.
    """
    return {
        name: getattr(self, f"mixed_derivative_prod_{name}")(u, p)
        for name in self.requires_grad
    }

make_activation ¤

make_activation(
    activation: Float[Array, "c 6"],
) -> Float[Array, "c 3 3"]
Source code in src/liblaf/apple/jax/sim/energy/elastic/utils.py
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
def make_activation(activation: Float[Array, "c 6"]) -> Float[Array, "c 3 3"]:
    """Expand packed symmetric activations into full ``3 x 3`` matrices.

    The six packed components map to matrix entries ``(0,0), (1,1), (2,2),
    (0,1), (0,2), (1,2)``. The off-diagonal ones are mirrored, so every
    output matrix is symmetric. The identity is NOT added: a rest state has
    to be encoded explicitly (``rest_activation`` sets the three diagonal
    components to one).

    Args:
        activation: Packed activations with last axis of length 6. Any
            leading batch shape is accepted (previously only ``(c, 6)``).

    Returns:
        Symmetric matrices of shape ``activation.shape[:-1] + (3, 3)`` with
        the same dtype as ``activation``.

    Raises:
        IndexError: If the last axis has fewer than six components.
    """
    # Static integer indices keep the IndexError for too-short inputs; an
    # array-based gather in JAX would silently clamp out-of-bounds indices.
    a00 = activation[..., 0]
    a11 = activation[..., 1]
    a22 = activation[..., 2]
    a01 = activation[..., 3]
    a02 = activation[..., 4]
    a12 = activation[..., 5]
    rows = (
        jnp.stack([a00, a01, a02], axis=-1),
        jnp.stack([a01, a11, a12], axis=-1),
        jnp.stack([a02, a12, a22], axis=-1),
    )
    return jnp.stack(rows, axis=-2)

rest_activation ¤

rest_activation(
    n_cells: int = 1, dtype: DTypeLike = float
) -> Float[Array, "c 6"]
Source code in src/liblaf/apple/jax/sim/energy/elastic/utils.py
22
23
24
25
26
27
def rest_activation(n_cells: int = 1, dtype: DTypeLike = float) -> Float[Array, "c 6"]:
    """Return packed rest-state activations for ``n_cells`` cells.

    Every row is ``(1, 1, 1, 0, 0, 0)``: ones in the three diagonal slots and
    zeros in the three off-diagonal slots of the packed layout used by
    ``make_activation``.
    """
    diagonal = jnp.ones((n_cells, 3), dtype)
    off_diagonal = jnp.zeros((n_cells, 3), dtype)
    return jnp.concatenate([diagonal, off_diagonal], axis=1)

transform_activation ¤

transform_activation(
    activation: Float[Array, "#c 6"],
    orientation: Float[Array, "#c 3 3"],
    *,
    inverse: bool = False,
) -> Float[Array, "c 6"]
Source code in src/liblaf/apple/jax/sim/energy/elastic/utils.py
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
def transform_activation(
    activation: Float[Array, "#c 6"],
    orientation: Float[Array, "#c 3 3"],
    *,
    inverse: bool = False,
) -> Float[Array, "c 6"]:
    """Conjugate packed activations by an orientation matrix.

    The activation is expanded with ``make_activation`` into a matrix ``A``.
    With ``Q = orientation``, the result is ``Q^T A Q`` by default, or
    ``Q A Q^T`` when ``inverse`` is true. The leading axes of ``activation``
    and ``orientation`` broadcast against each other.

    Returns:
        The transformed matrices re-packed as entries ``(0,0), (1,1), (2,2),
        (0,1), (0,2), (1,2)``, read from the upper triangle, and cast to
        ``activation.dtype``.
    """
    A = make_activation(activation)
    Q = orientation.mT if inverse else orientation
    rotated = Q.mT @ A @ Q
    # Packed column k is read from matrix entry (rows[k], cols[k]).
    rows = jnp.array([0, 1, 2, 0, 0, 1])
    cols = jnp.array([0, 1, 2, 1, 2, 2])
    return rotated[:, rows, cols].astype(activation.dtype)