Skip to content

liblaf.apple.jax ¤

Modules:

  • energies
  • fem
  • model

Classes:

  • Dirichlet
  • DirichletBuilder
  • Element

    Base class for a finite element that provides shape functions and their derivatives.

  • ElementTetra
  • Geometry
  • GeometryTetra
  • GeometryTriangle
  • JaxEnergy
  • JaxEnergyState
  • JaxMassSpring
  • JaxMassSpringPrestrain
  • JaxModel
  • JaxModelBuilder
  • JaxModelState
  • JaxPointForce
  • QuadratureTetra
  • Region
  • Scheme

Dirichlet ¤

Parameters:

  • dim ¤

    (int) –
  • dirichlet_index ¤

    (Integer[Array, dirichlet]) –
  • dirichlet_value ¤

    (Float[Array, dirichlet]) –
  • fixed_mask ¤

    (Bool[Array, 'points dim']) –
  • free_index ¤

    (Integer[Array, free]) –
  • n_points ¤

    (int) –

Methods:

  • get_fixed
  • get_free
  • set_fixed
  • set_free
  • to_full

Attributes:

  • dim (int) –
  • dirichlet_index (Integer[Array, ' dirichlet']) –
  • dirichlet_value (Float[Array, ' dirichlet']) –
  • fixed_mask (Bool[Array, 'points dim']) –
  • free_index (Integer[Array, ' free']) –
  • n_dirichlet (int) –
  • n_free (int) –
  • n_full (int) –
  • n_points (int) –

dim class-attribute instance-attribute ¤

dim: int = static()

dirichlet_index instance-attribute ¤

dirichlet_index: Integer[Array, ' dirichlet']

dirichlet_value instance-attribute ¤

dirichlet_value: Float[Array, ' dirichlet']

fixed_mask instance-attribute ¤

fixed_mask: Bool[Array, 'points dim']

free_index instance-attribute ¤

free_index: Integer[Array, ' free']

n_dirichlet property ¤

n_dirichlet: int

n_free property ¤

n_free: int

n_full property ¤

n_full: int

n_points class-attribute instance-attribute ¤

n_points: int = static()

get_fixed ¤

get_fixed(
    full: Float[Array, "points dim"],
) -> Float[Array, " dirichlet"]
Source code in src/liblaf/apple/jax/model/dirichlet/_dirichlet.py
27
28
29
@jarp.jit(inline=True)
def get_fixed(self, full: Float[Array, "points dim"]) -> Float[Array, " dirichlet"]:
    """Gather the Dirichlet (fixed) entries of a full per-point field.

    Args:
        full: Per-point field; it is flattened before indexing.

    Returns:
        ``full.flatten()`` restricted to ``self.dirichlet_index``.
    """
    flat = full.flatten()
    return flat[self.dirichlet_index]

get_free ¤

get_free(
    full: Float[Array, "points dim"],
) -> Float[Array, " free"]
Source code in src/liblaf/apple/jax/model/dirichlet/_dirichlet.py
31
32
33
@jarp.jit(inline=True)
def get_free(self, full: Float[Array, "points dim"]) -> Float[Array, " free"]:
    """Gather the free (unconstrained) entries of a full per-point field.

    Args:
        full: Per-point field; it is flattened before indexing.

    Returns:
        ``full.flatten()`` restricted to ``self.free_index``.
    """
    flat = full.flatten()
    return flat[self.free_index]

set_fixed ¤

set_fixed(
    full: Float[Array, "points dim"],
    values: Float[ArrayLike, " dirichlet"] | None = None,
) -> Float[Array, "points dim"]
Source code in src/liblaf/apple/jax/model/dirichlet/_dirichlet.py
35
36
37
38
39
40
41
42
43
@jarp.jit(inline=True)
def set_fixed(
    self,
    full: Float[Array, "points dim"],
    values: Float[ArrayLike, " dirichlet"] | None = None,
) -> Float[Array, "points dim"]:
    """Return a copy of ``full`` with its Dirichlet entries overwritten.

    Args:
        full: Per-point field to update. It is not modified in place, since
            the functional ``.at[].set`` update returns a new array.
        values: Values written at ``self.dirichlet_index``. When omitted,
            ``self.dirichlet_value`` is used.

    Returns:
        An array with the same shape as ``full``.
    """
    fixed_values = self.dirichlet_value if values is None else values
    flat = full.flatten()
    updated = flat.at[self.dirichlet_index].set(fixed_values)
    return updated.reshape(full.shape)

set_free ¤

set_free(
    full: Float[Array, "points dim"],
    values: Float[ArrayLike, " free"],
) -> Float[Array, "points dim"]
Source code in src/liblaf/apple/jax/model/dirichlet/_dirichlet.py
45
46
47
48
49
@jarp.jit(inline=True)
def set_free(
    self, full: Float[Array, "points dim"], values: Float[ArrayLike, " free"]
) -> Float[Array, "points dim"]:
    """Return a copy of ``full`` with its free entries replaced by ``values``.

    Args:
        full: Per-point field to update (not modified in place).
        values: Values written at ``self.free_index`` of the flattened field.

    Returns:
        An array with the same shape as ``full``.
    """
    flat = full.flatten()
    updated = flat.at[self.free_index].set(values)
    return updated.reshape(full.shape)

to_full ¤

to_full(
    free: Float[Array, " free"],
    dirichlet: Float[ArrayLike, " dirichlet"] | None = None,
) -> Float[Array, "points dim"]
Source code in src/liblaf/apple/jax/model/dirichlet/_dirichlet.py
51
52
53
54
55
56
57
58
59
60
61
62
@jarp.jit(inline=True)
def to_full(
    self,
    free: Float[Array, " free"],
    dirichlet: Float[ArrayLike, " dirichlet"] | None = None,
) -> Float[Array, "points dim"]:
    """Assemble a full ``(n_points, dim)`` field from free and fixed values.

    Args:
        free: Values for the free entries.
        dirichlet: Values for the Dirichlet entries. When ``None``,
            ``set_fixed`` falls back to ``self.dirichlet_value``.

    Returns:
        A ``(n_points, dim)`` array with the dtype of ``free``.
    """
    template = jnp.empty((self.n_points, self.dim), free.dtype)
    # Free entries are written first, then the fixed ones.
    with_free = self.set_free(template, free)
    return self.set_fixed(with_free, dirichlet)

DirichletBuilder ¤

DirichletBuilder(dim: int = 3)

Parameters:

  • mask ¤

    (Bool[ndarray, 'points dim']) –
  • value ¤

    (Float[ndarray, 'points dim']) –

Methods:

  • add_pyvista
  • finalize
  • resize

Attributes:

  • dim (int) –
  • mask (Bool[ndarray, 'points dim']) –
  • n_points (int) –
  • value (Float[ndarray, 'points dim']) –
Source code in src/liblaf/apple/jax/model/dirichlet/_builder.py
19
20
21
22
def __init__(self, dim: int = 3) -> None:
    """Create an empty builder: zero points, ``dim`` components per point.

    Args:
        dim: Number of components per point (columns of ``mask``/``value``).
    """
    empty_shape = (0, dim)
    self.__attrs_init__(  # pyright: ignore[reportAttributeAccessIssue]
        mask=np.empty(empty_shape, bool),
        value=np.empty(empty_shape),
    )

dim property ¤

dim: int

mask instance-attribute ¤

mask: Bool[ndarray, 'points dim']

n_points property ¤

n_points: int

value instance-attribute ¤

value: Float[ndarray, 'points dim']

add_pyvista ¤

add_pyvista(obj: DataSet) -> None
Source code in src/liblaf/apple/jax/model/dirichlet/_builder.py
32
33
34
35
36
37
38
39
40
41
42
def add_pyvista(self, obj: pv.DataSet) -> None:
    """Merge Dirichlet constraints stored as point data on ``obj`` into the builder.

    Rows are addressed by ``obj.point_data[GLOBAL_POINT_ID]``. The builder is
    grown (never shrunk) so that the largest id fits. The mask and value
    point arrays are passed through ``_left_broadcast_to`` with target shape
    ``(obj.n_points, self.dim)`` before being written.

    Args:
        obj: Mesh carrying ``GLOBAL_POINT_ID``, ``DIRICHLET_MASK`` and
            ``DIRICHLET_VALUE`` point data.
    """
    point_id = obj.point_data[GLOBAL_POINT_ID]
    if point_id.size == 0:
        # An empty mesh contributes nothing. ``max()`` of an empty array
        # raises ValueError, so return before calling it.
        return
    self.resize(point_id.max() + 1)
    dirichlet_mask: Bool[np.ndarray, "points dim"] = _left_broadcast_to(
        obj.point_data[DIRICHLET_MASK], (obj.n_points, self.dim)
    )
    dirichlet_value: Float[np.ndarray, "points dim"] = _left_broadcast_to(
        obj.point_data[DIRICHLET_VALUE], (obj.n_points, self.dim)
    )
    # If ids repeat within ``obj``, the later rows overwrite the earlier ones.
    self.mask[point_id] = dirichlet_mask
    self.value[point_id] = dirichlet_value

finalize ¤

finalize() -> Dirichlet
Source code in src/liblaf/apple/jax/model/dirichlet/_builder.py
44
45
46
47
48
49
50
51
52
53
54
def finalize(self) -> Dirichlet:
    """Freeze the accumulated mask and values into a ``Dirichlet`` object.

    Indices refer to the flattened ``(n_points, dim)`` layout. Entries where
    ``mask`` is true are fixed; all other entries are free.

    Returns:
        A ``Dirichlet`` whose ``dirichlet_value`` holds ``self.value`` at the
        fixed entries.
    """
    fixed_mask: Bool[Array, "points dim"] = jnp.asarray(self.mask)
    fixed_idx: Integer[Array, " dirichlet"] = jnp.flatnonzero(fixed_mask)
    fixed_vals = jnp.asarray(self.value.ravel()[fixed_idx])
    free_idx = jnp.flatnonzero(~fixed_mask)
    return Dirichlet(
        dim=self.dim,
        dirichlet_index=fixed_idx,
        dirichlet_value=fixed_vals,
        fixed_mask=fixed_mask,
        free_index=free_idx,
        n_points=self.n_points,
    )

resize ¤

resize(n_points: int) -> None
Source code in src/liblaf/apple/jax/model/dirichlet/_builder.py
56
57
58
59
60
61
def resize(self, n_points: int) -> None:
    """Grow the builder to ``n_points`` rows; it never shrinks.

    New rows are unconstrained (mask ``False``) with value ``0.0``.

    Args:
        n_points: Desired number of rows. Values not larger than the
            current ``self.n_points`` are a no-op.
    """
    extra = n_points - self.n_points
    if extra > 0:
        padding = ((0, extra), (0, 0))
        self.mask = np.pad(self.mask, padding, constant_values=False)
        self.value = np.pad(self.value, padding, constant_values=0.0)

Element ¤

Base class for a finite element that provides shape functions and their derivatives.

References
  1. felupe.Element

Methods:

  • function

    Return the shape functions at given coordinates.

  • gradient
  • hessian

Attributes:

  • cells (Integer[Array, ' points']) –
  • dim (int) –
  • n_points (int) –
  • points (Float[Array, 'points dim']) –
  • quadrature (Scheme) –

cells property ¤

cells: Integer[Array, ' points']

dim property ¤

dim: int

n_points property ¤

n_points: int

points property ¤

points: Float[Array, 'points dim']

quadrature property ¤

quadrature: Scheme

function ¤

function(
    coords: Float[Array, " dim"],
) -> Float[Array, " points"]

Return the shape functions at given coordinates.

Source code in src/liblaf/apple/jax/fem/element/_element.py
38
39
40
def function(self, coords: Float[Array, " dim"]) -> Float[Array, " points"]:
    """Return the shape functions at given coordinates.

    This is an abstract hook that concrete elements must override.

    Raises:
        NotImplementedError: Always, in the base class.
    """
    raise NotImplementedError

gradient ¤

gradient(
    coords: Float[Array, " dim"],
) -> Float[Array, "points dim"]
Source code in src/liblaf/apple/jax/fem/element/_element.py
42
43
def gradient(self, coords: Float[Array, " dim"]) -> Float[Array, "points dim"]:
    """Differentiate ``self.function`` with respect to ``coords``.

    Uses ``jax.jacobian``, so entry ``[a, i]`` is ``dN_a / dcoords_i``.
    """
    jacobian_fn = jax.jacobian(self.function)
    return jacobian_fn(coords)

hessian ¤

hessian(
    coords: Float[Array, " dim"],
) -> Float[Array, "points dim dim"]
Source code in src/liblaf/apple/jax/fem/element/_element.py
45
46
def hessian(self, coords: Float[Array, " dim"]) -> Float[Array, "points dim dim"]:
    """Second derivatives of ``self.function`` with respect to ``coords``.

    Uses ``jax.hessian``, so entry ``[a, i, j]`` is
    ``d^2 N_a / (dcoords_i dcoords_j)``.
    """
    hessian_fn = jax.hessian(self.function)
    return hessian_fn(coords)

ElementTetra ¤

Bases: Element


              flowchart TD
              liblaf.apple.jax.ElementTetra[ElementTetra]
              liblaf.apple.jax.fem.element._element.Element[Element]

                              liblaf.apple.jax.fem.element._element.Element --> liblaf.apple.jax.ElementTetra
                


              click liblaf.apple.jax.ElementTetra href "" "liblaf.apple.jax.ElementTetra"
              click liblaf.apple.jax.fem.element._element.Element href "" "liblaf.apple.jax.fem.element._element.Element"
            

Methods:

  • function

    Return the shape functions at given coordinates.

  • gradient
  • hessian

Attributes:

  • cells (Integer[Array, ' points']) –
  • dim (int) –
  • n_points (int) –
  • points (Float[Array, 'points dim']) –
  • quadrature (QuadratureTetra) –

cells property ¤

cells: Integer[Array, ' points']

dim property ¤

dim: int

n_points property ¤

n_points: int

points property ¤

points: Float[Array, 'points dim']

quadrature property ¤

quadrature: QuadratureTetra

function ¤

function(
    coords: Float[Array, " dim"],
) -> Float[Array, "points=4"]

Return the shape functions at given coordinates.

Source code in src/liblaf/apple/jax/fem/element/_tetra.py
31
32
33
34
@override
def function(self, coords: Float[Array, " dim"]) -> Float[Array, "points=4"]:
    """Return the four linear tetrahedron shape functions at ``coords``.

    For reference coordinates ``(r, s, t)`` the values are
    ``[1 - r - s - t, r, s, t]``.
    """
    r, s, t = coords
    first = 1.0 - r - s - t
    return jnp.asarray([first, r, s, t])

gradient ¤

gradient(
    coords: Float[Array, " dim"],
) -> Float[Array, "points dim"]
Source code in src/liblaf/apple/jax/fem/element/_tetra.py
36
37
38
39
40
41
42
43
44
45
@override
def gradient(self, coords: Float[Array, " dim"]) -> Float[Array, "points dim"]:
    """Return the shape-function gradients of the linear tetrahedron.

    ``coords`` is ignored because linear shape functions have constant
    derivatives. Row ``a`` is ``dN_a / d(r, s, t)``.
    """
    rows = [
        [-1.0, -1.0, -1.0],  # N0 = 1 - r - s - t
        [1.0, 0.0, 0.0],  # N1 = r
        [0.0, 1.0, 0.0],  # N2 = s
        [0.0, 0.0, 1.0],  # N3 = t
    ]
    return jnp.asarray(rows)

hessian ¤

hessian(
    coords: Float[Array, " dim"],
) -> Float[Array, "points dim dim"]
Source code in src/liblaf/apple/jax/fem/element/_tetra.py
47
48
49
@override
def hessian(self, coords: Float[Array, " dim"]) -> Float[Array, "points dim dim"]:
    """Return zeros: linear shape functions have vanishing second derivatives."""
    n_nodes, n_dim = 4, 3
    return jnp.zeros((n_nodes, n_dim, n_dim))

Geometry ¤

Parameters:

Methods:

  • from_pyvista

Attributes:

  • cell_data (DataSetAttributes) –
  • cells_global (Integer[Array, 'c a']) –
  • cells_local (Integer[Array, 'c a']) –
  • element (Element) –
  • global_point_id (Integer[Array, 'p J']) –
  • mesh (DataSet) –
  • n_cells (int) –
  • point_data (DataSetAttributes) –
  • points (Float[Array, 'p J']) –

cell_data property ¤

cell_data: DataSetAttributes

cells_global property ¤

cells_global: Integer[Array, 'c a']

cells_local property ¤

cells_local: Integer[Array, 'c a']

element property ¤

element: Element

global_point_id property ¤

global_point_id: Integer[Array, 'p J']

mesh class-attribute instance-attribute ¤

mesh: DataSet = field()

n_cells property ¤

n_cells: int

point_data property ¤

point_data: DataSetAttributes

points property ¤

points: Float[Array, 'p J']

from_pyvista classmethod ¤

from_pyvista(mesh: DataObject) -> Geometry
Source code in src/liblaf/apple/jax/fem/geometry/_geometry.py
16
17
18
19
20
21
22
23
24
25
@classmethod
def from_pyvista(cls, mesh: pv.DataObject) -> Geometry:
    """Build the geometry subclass matching the PyVista mesh type.

    ``pv.PolyData`` is dispatched to ``GeometryTriangle`` and
    ``pv.UnstructuredGrid`` to ``GeometryTetra``.

    Raises:
        NotImplementedError: For any other mesh type. The message names
            the offending type.
    """
    # Imported lazily, presumably to avoid a circular import with the
    # subclass modules (which import this base class).
    from ._tetra import GeometryTetra
    from ._triangle import GeometryTriangle

    if isinstance(mesh, pv.PolyData):
        return GeometryTriangle.from_pyvista(mesh)
    if isinstance(mesh, pv.UnstructuredGrid):
        return GeometryTetra.from_pyvista(mesh)
    msg = f"Unsupported mesh type: {type(mesh).__name__}"
    raise NotImplementedError(msg)

GeometryTetra ¤

Bases: Geometry


              flowchart TD
              liblaf.apple.jax.GeometryTetra[GeometryTetra]
              liblaf.apple.jax.fem.geometry._geometry.Geometry[Geometry]

                              liblaf.apple.jax.fem.geometry._geometry.Geometry --> liblaf.apple.jax.GeometryTetra
                


              click liblaf.apple.jax.GeometryTetra href "" "liblaf.apple.jax.GeometryTetra"
              click liblaf.apple.jax.fem.geometry._geometry.Geometry href "" "liblaf.apple.jax.fem.geometry._geometry.Geometry"
            

Parameters:

  • mesh ¤

    (UnstructuredGrid) –

Methods:

  • from_pyvista

Attributes:

  • cell_data (DataSetAttributes) –
  • cells_global (Integer[Array, 'c a']) –
  • cells_local (Integer[Array, 'c a']) –
  • element (ElementTetra) –
  • global_point_id (Integer[Array, 'p J']) –
  • mesh (UnstructuredGrid) –
  • n_cells (int) –
  • point_data (DataSetAttributes) –
  • points (Float[Array, 'p J']) –

cell_data property ¤

cell_data: DataSetAttributes

cells_global property ¤

cells_global: Integer[Array, 'c a']

cells_local property ¤

cells_local: Integer[Array, 'c a']

element property ¤

element: ElementTetra

global_point_id property ¤

global_point_id: Integer[Array, 'p J']

mesh class-attribute instance-attribute ¤

mesh: UnstructuredGrid = field()

n_cells property ¤

n_cells: int

point_data property ¤

point_data: DataSetAttributes

points property ¤

points: Float[Array, 'p J']

from_pyvista classmethod ¤

from_pyvista(mesh: UnstructuredGrid) -> Self
Source code in src/liblaf/apple/jax/fem/geometry/_tetra.py
17
18
19
20
21
@override
@classmethod
def from_pyvista(cls, mesh: pv.UnstructuredGrid) -> Self:  # pyright: ignore[reportIncompatibleMethodOverride]
    """Wrap an unstructured grid as-is (no conversion is applied)."""
    return cls(mesh=mesh)

GeometryTriangle ¤

Bases: Geometry


              flowchart TD
              liblaf.apple.jax.GeometryTriangle[GeometryTriangle]
              liblaf.apple.jax.fem.geometry._geometry.Geometry[Geometry]

                              liblaf.apple.jax.fem.geometry._geometry.Geometry --> liblaf.apple.jax.GeometryTriangle
                


              click liblaf.apple.jax.GeometryTriangle href "" "liblaf.apple.jax.GeometryTriangle"
              click liblaf.apple.jax.fem.geometry._geometry.Geometry href "" "liblaf.apple.jax.fem.geometry._geometry.Geometry"
            

Parameters:

  • mesh ¤

    (PolyData) –

Methods:

  • from_pyvista

Attributes:

  • cell_data (DataSetAttributes) –
  • cells_global (Integer[Array, 'c a']) –
  • cells_local (Integer[Array, 'c a']) –
  • element (Element) –
  • global_point_id (Integer[Array, 'p J']) –
  • mesh (PolyData) –
  • n_cells (int) –
  • point_data (DataSetAttributes) –
  • points (Float[Array, 'p J']) –

cell_data property ¤

cell_data: DataSetAttributes

cells_global property ¤

cells_global: Integer[Array, 'c a']

cells_local property ¤

cells_local: Integer[Array, 'c a']

element property ¤

element: Element

global_point_id property ¤

global_point_id: Integer[Array, 'p J']

mesh class-attribute instance-attribute ¤

mesh: PolyData = static()

n_cells property ¤

n_cells: int

point_data property ¤

point_data: DataSetAttributes

points property ¤

points: Float[Array, 'p J']

from_pyvista classmethod ¤

from_pyvista(mesh: PolyData) -> Self
Source code in src/liblaf/apple/jax/fem/geometry/_triangle.py
15
16
17
18
19
20
@override
@classmethod
def from_pyvista(cls, mesh: pv.PolyData) -> Self:  # pyright: ignore[reportIncompatibleMethodOverride]
    """Wrap a surface mesh after running ``mesh.triangulate()`` on it."""
    triangulated: pv.PolyData = mesh.triangulate()  # pyright: ignore[reportAssignmentType]
    return cls(mesh=triangulated)

JaxEnergy ¤

Parameters:

  • name ¤

    (str, default: <dynamic> ) –
  • requires_grad ¤

    (frozenset[str], default: frozenset() ) –

Methods:

  • fun
  • grad
  • grad_and_hess_diag
  • hess_diag
  • hess_prod
  • hess_quad
  • init_state
  • mixed_derivative_prod
  • update
  • update_materials
  • value_and_grad

Attributes:

name class-attribute instance-attribute ¤

name: str = static(default=name_factory, kw_only=True)

requires_grad class-attribute instance-attribute ¤

requires_grad: frozenset[str] = static(
    default=frozenset(), kw_only=True
)

fun ¤

fun(state: JaxEnergyState, u: Vector) -> Scalar
Source code in src/liblaf/apple/jax/model/_energy.py
34
35
def fun(self, state: JaxEnergyState, u: Vector) -> Scalar:
    """Energy value at displacement ``u``; must be provided by subclasses.

    Raises:
        NotImplementedError: Always, in the base class.
    """
    raise NotImplementedError

grad ¤

grad(state: JaxEnergyState, u: Vector) -> Updates
Source code in src/liblaf/apple/jax/model/_energy.py
37
38
39
40
@jarp.jit(inline=True)
def grad(self, state: JaxEnergyState, u: Vector) -> Updates:
    """Gradient of ``fun`` with respect to ``u``, as a ``(values, index)`` update.

    The gradient comes from ``eqx.filter_grad``. The index is
    ``arange(u.shape[0])``, so every row of ``u`` is covered.
    """
    energy_of_u = jarp.partial(self.fun, state)
    values: Vector = eqx.filter_grad(energy_of_u)(u)
    rows = jnp.arange(u.shape[0])
    return values, rows

grad_and_hess_diag ¤

grad_and_hess_diag(
    state: JaxEnergyState, u: Vector
) -> tuple[Updates, Updates]
Source code in src/liblaf/apple/jax/model/_energy.py
67
68
69
70
71
@jarp.jit(inline=True)
def grad_and_hess_diag(
    self, state: JaxEnergyState, u: Vector
) -> tuple[Updates, Updates]:
    """Return ``(self.grad(state, u), self.hess_diag(state, u))``."""
    grad_updates = self.grad(state, u)
    diag_updates = self.hess_diag(state, u)
    return grad_updates, diag_updates

hess_diag ¤

hess_diag(state: JaxEnergyState, u: Vector) -> Updates
Source code in src/liblaf/apple/jax/model/_energy.py
42
43
def hess_diag(self, state: JaxEnergyState, u: Vector) -> Updates:
    """Diagonal of the Hessian of ``fun``; must be provided by subclasses.

    Raises:
        NotImplementedError: Always, in the base class.
    """
    raise NotImplementedError

hess_prod ¤

hess_prod(
    state: JaxEnergyState, u: Vector, p: Vector
) -> Updates
Source code in src/liblaf/apple/jax/model/_energy.py
45
46
47
48
49
@jarp.jit(inline=True)
def hess_prod(self, state: JaxEnergyState, u: Vector, p: Vector) -> Updates:
    """Hessian-vector product ``H(u) @ p`` as a dense ``(values, index)`` update.

    Computed as the JVP of ``jax.grad(fun)`` at ``u`` in direction ``p``
    (forward-over-reverse).
    """
    energy_of_u = jarp.partial(self.fun, state)
    _primal, tangent = jax.jvp(jax.grad(energy_of_u), (u,), (p,))
    return tangent, jnp.arange(u.shape[0])

hess_quad ¤

hess_quad(
    state: JaxEnergyState, u: Vector, p: Vector
) -> Scalar
Source code in src/liblaf/apple/jax/model/_energy.py
51
52
53
54
55
56
@jarp.jit(inline=True)
def hess_quad(self, state: JaxEnergyState, u: Vector, p: Vector) -> Scalar:
    """Quadratic form ``p^T H(u) p``, built from ``self.hess_prod``.

    ``p`` is gathered at the update's index before the (flattening)
    ``jnp.vdot``, so repeated indices each contribute their own term.
    """
    hp_values, hp_index = self.hess_prod(state, u, p)
    return jnp.vdot(p[hp_index], hp_values)

init_state ¤

init_state(u: Vector) -> JaxEnergyState
Source code in src/liblaf/apple/jax/model/_energy.py
25
26
def init_state(self, u: Vector) -> JaxEnergyState:  # noqa: ARG002
    """Create the initial state: a default-constructed ``JaxEnergyState``.

    ``u`` is accepted for interface compatibility and is unused here.
    """
    fresh_state = JaxEnergyState()
    return fresh_state

mixed_derivative_prod ¤

mixed_derivative_prod(
    state: JaxEnergyState, u: Vector, p: Vector
) -> dict[str, Array]
Source code in src/liblaf/apple/jax/model/_energy.py
73
74
75
76
77
78
79
80
@jarp.jit(inline=True)
def mixed_derivative_prod(
    self, state: JaxEnergyState, u: Vector, p: Vector
) -> dict[str, Array]:
    """Collect mixed-derivative products for every name in ``requires_grad``.

    For each ``name``, this calls the method
    ``mixed_derivative_prod_<name>(state, u, p)`` and stores its result
    under ``name``.
    """
    return {
        name: getattr(self, f"mixed_derivative_prod_{name}")(state, u, p)
        for name in self.requires_grad
    }

update ¤

update(state: JaxEnergyState, u: Vector) -> JaxEnergyState
Source code in src/liblaf/apple/jax/model/_energy.py
28
29
def update(self, state: JaxEnergyState, u: Vector) -> JaxEnergyState:  # noqa: ARG002
    """Advance the state for a new ``u``; the base energy returns it unchanged."""
    unchanged = state
    return unchanged

update_materials ¤

update_materials(materials: EnergyMaterials) -> None
Source code in src/liblaf/apple/jax/model/_energy.py
31
32
def update_materials(self, materials: EnergyMaterials) -> None:
    """Hook for subclasses to ingest new material parameters; a no-op here."""

value_and_grad ¤

value_and_grad(
    state: JaxEnergyState, u: Vector
) -> tuple[Scalar, Updates]
Source code in src/liblaf/apple/jax/model/_energy.py
58
59
60
61
62
63
64
65
@jarp.jit(inline=True)
def value_and_grad(
    self, state: JaxEnergyState, u: Vector
) -> tuple[Scalar, Updates]:
    """Return ``fun(state, u)`` together with its gradient as a dense update.

    Both come from a single ``jax.value_and_grad`` call on ``fun``. The
    gradient index is ``arange(u.shape[0])``.
    """
    energy_of_u = jarp.partial(self.fun, state)
    energy, dense_grad = jax.value_and_grad(energy_of_u)(u)
    rows = jnp.arange(u.shape[0])
    return energy, (dense_grad, rows)

JaxEnergyState ¤

JaxMassSpring ¤

Bases: JaxEnergy


              flowchart TD
              liblaf.apple.jax.JaxMassSpring[JaxMassSpring]
              liblaf.apple.jax.model._energy.JaxEnergy[JaxEnergy]

                              liblaf.apple.jax.model._energy.JaxEnergy --> liblaf.apple.jax.JaxMassSpring
                


              click liblaf.apple.jax.JaxMassSpring href "" "liblaf.apple.jax.JaxMassSpring"
              click liblaf.apple.jax.model._energy.JaxEnergy href "" "liblaf.apple.jax.model._energy.JaxEnergy"
            

Parameters:

  • name ¤

    (str, default: <dynamic> ) –
  • requires_grad ¤

    (frozenset[str], default: frozenset() ) –
  • edges ¤

    (Integer[Array, 'edges 2']) –
  • length ¤

    (Float[Array, edges]) –
  • points ¤

    (Float[Array, 'edges 2 3']) –
  • stiffness ¤

    (Float[Array, edges]) –

Methods:

  • from_pyvista
  • fun
  • grad
  • grad_and_hess_diag
  • hess_diag
  • hess_prod
  • hess_quad
  • init_state
  • mixed_derivative_prod
  • update
  • update_materials
  • value_and_grad

Attributes:

  • edges (Integer[Array, ' edges 2']) –
  • length (Float[Array, ' edges']) –
  • n_edges (int) –
  • name (str) –
  • points (Float[Array, 'edges 2 3']) –
  • requires_grad (frozenset[str]) –
  • stiffness (Float[Array, ' edges']) –

edges instance-attribute ¤

edges: Integer[Array, ' edges 2']

length instance-attribute ¤

length: Float[Array, ' edges']

n_edges property ¤

n_edges: int

name class-attribute instance-attribute ¤

name: str = static(default=name_factory, kw_only=True)

points instance-attribute ¤

points: Float[Array, 'edges 2 3']

requires_grad class-attribute instance-attribute ¤

requires_grad: frozenset[str] = static(
    default=frozenset(), kw_only=True
)

stiffness instance-attribute ¤

stiffness: Float[Array, ' edges']

from_pyvista classmethod ¤

from_pyvista(obj: PolyData) -> Self
Source code in src/liblaf/apple/jax/energies/_mass_spring.py
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
@classmethod
def from_pyvista(cls, obj: pv.PolyData) -> Self:
    """Build springs from the line cells of ``obj``.

    Rest lengths are read from ``obj.cell_data[LENGTH]``. If that array is
    absent, ``compute_cell_sizes(length=True)`` is run first, presumably
    filling ``LENGTH`` with the current segment lengths (confirm that the
    constant matches PyVista's output name). Edge endpoints are mapped to
    global ids through ``GLOBAL_POINT_ID`` point data.
    """
    if LENGTH not in obj.cell_data:
        obj = obj.compute_cell_sizes(length=True, area=False, volume=False)  # pyright: ignore[reportAssignmentType]
    global_ids: Integer[np.ndarray, " points"] = obj.point_data[GLOBAL_POINT_ID]
    # Assumes every line cell is a two-point segment stored as ``[2, i, j]``.
    local_edges: Integer[np.ndarray, "edges 2"] = obj.lines.reshape((-1, 3))[:, 1:]
    rest_length: Float[Array, " edges"] = jnp.asarray(obj.cell_data[LENGTH])
    if jnp.any(rest_length < 0.0):
        logger.warning("Length < 0")
    return cls(
        edges=jnp.asarray(global_ids[local_edges]),
        length=rest_length,
        points=jnp.asarray(obj.points[local_edges]),
        stiffness=jnp.asarray(obj.cell_data[STIFFNESS]),
    )

fun ¤

fun(state: JaxEnergyState, u: Vector) -> Scalar
Source code in src/liblaf/apple/jax/energies/_mass_spring.py
49
50
51
52
53
54
55
56
57
58
@override
def fun(self, state: JaxEnergyState, u: Vector) -> Scalar:
    """Total spring energy ``sum_e 0.5 * k_e * (|x1 - x0| - L_e)^2``.

    Endpoint positions are ``self.points + u[self.edges]``.
    """
    deformed = self.points + u[self.edges]
    edge_vec = deformed[:, 1, :] - deformed[:, 0, :]
    stretch = jnp.linalg.norm(edge_vec, axis=-1) - self.length
    per_edge = 0.5 * self.stiffness * jnp.square(stretch)
    return jnp.sum(per_edge)

grad ¤

grad(state: JaxEnergyState, u: Vector) -> Updates
Source code in src/liblaf/apple/jax/energies/_mass_spring.py
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
@override
def grad(self, state: JaxEnergyState, u: Vector) -> Updates:
    """Analytic gradient of ``fun`` as a per-endpoint ``(values, index)`` update.

    For each edge, ``f = k * (l - L) * d / l`` with ``d = x1 - x0`` and
    ``l = |d|``. Endpoint 0 receives ``-f`` and endpoint 1 receives ``+f``.
    The rows line up with ``self.edges.flatten()``.
    """
    deformed = self.points + u[self.edges]
    edge_vec = deformed[:, 1, :] - deformed[:, 0, :]
    cur_len = jnp.linalg.norm(edge_vec, axis=-1)
    # Divide by 1 for zero-length springs so the direction becomes 0, not NaN.
    safe_len = jnp.where(cur_len > 0, cur_len, 1.0)
    unit = edge_vec / safe_len[:, jnp.newaxis]
    force = (
        self.stiffness[:, jnp.newaxis]
        * (cur_len - self.length)[:, jnp.newaxis]
        * unit
    )
    per_endpoint = jnp.stack([-force, force], axis=1)
    return per_endpoint.reshape(-1, 3), self.edges.flatten()

grad_and_hess_diag ¤

grad_and_hess_diag(
    state: JaxEnergyState, u: Vector
) -> tuple[Updates, Updates]
Source code in src/liblaf/apple/jax/model/_energy.py
67
68
69
70
71
@jarp.jit(inline=True)
def grad_and_hess_diag(
    self, state: JaxEnergyState, u: Vector
) -> tuple[Updates, Updates]:
    """Return ``(self.grad(state, u), self.hess_diag(state, u))``."""
    grad_updates = self.grad(state, u)
    diag_updates = self.hess_diag(state, u)
    return grad_updates, diag_updates

hess_diag ¤

hess_diag(state: JaxEnergyState, u: Vector) -> Updates
Source code in src/liblaf/apple/jax/energies/_mass_spring.py
76
77
78
79
80
81
@override
def hess_diag(self, state: JaxEnergyState, u: Vector) -> Updates:
    """Approximate Hessian diagonal: ``k_e`` for every component of both endpoints.

    The result does not depend on ``u``. Rows line up with
    ``self.edges.flatten()``.

    NOTE(review): differentiating ``fun`` twice gives the exact per-endpoint
    diagonal ``k * ((1 - L/l) + (L/l) * n_j**2)``, where ``l = |x1 - x0|``
    and ``n`` is the unit edge direction. Returning ``k`` (its limit for
    ``l >> L`` and an upper bound for ``l > 0``) looks like an intentional
    positive preconditioner. Confirm before relying on it as exact.
    """
    diag: Float[Array, "edges*2 3"] = einops.repeat(
        self.stiffness, "edges -> (edges i) j", i=2, j=3
    )
    endpoint_index = self.edges.flatten()
    return diag, endpoint_index

hess_prod ¤

hess_prod(
    state: JaxEnergyState, u: Vector, p: Vector
) -> Updates
Source code in src/liblaf/apple/jax/model/_energy.py
45
46
47
48
49
@jarp.jit(inline=True)
def hess_prod(self, state: JaxEnergyState, u: Vector, p: Vector) -> Updates:
    """Hessian-vector product ``H(u) @ p`` as a dense ``(values, index)`` update.

    Computed as the JVP of ``jax.grad(fun)`` at ``u`` in direction ``p``.
    """
    energy_of_u = jarp.partial(self.fun, state)
    _primal, tangent = jax.jvp(jax.grad(energy_of_u), (u,), (p,))
    return tangent, jnp.arange(u.shape[0])

hess_quad ¤

hess_quad(
    state: JaxEnergyState, u: Vector, p: Vector
) -> Scalar
Source code in src/liblaf/apple/jax/model/_energy.py
51
52
53
54
55
56
@jarp.jit(inline=True)
def hess_quad(self, state: JaxEnergyState, u: Vector, p: Vector) -> Scalar:
    """Quadratic form ``p^T H(u) p``, built from ``self.hess_prod``."""
    hp_values, hp_index = self.hess_prod(state, u, p)
    return jnp.vdot(p[hp_index], hp_values)

init_state ¤

init_state(u: Vector) -> JaxEnergyState
Source code in src/liblaf/apple/jax/model/_energy.py
25
26
def init_state(self, u: Vector) -> JaxEnergyState:  # noqa: ARG002
    """Create the initial state: a default-constructed ``JaxEnergyState``."""
    fresh_state = JaxEnergyState()
    return fresh_state

mixed_derivative_prod ¤

mixed_derivative_prod(
    state: JaxEnergyState, u: Vector, p: Vector
) -> dict[str, Array]
Source code in src/liblaf/apple/jax/model/_energy.py
73
74
75
76
77
78
79
80
@jarp.jit(inline=True)
def mixed_derivative_prod(
    self, state: JaxEnergyState, u: Vector, p: Vector
) -> dict[str, Array]:
    """Map each name in ``requires_grad`` to ``mixed_derivative_prod_<name>(state, u, p)``."""
    return {
        name: getattr(self, f"mixed_derivative_prod_{name}")(state, u, p)
        for name in self.requires_grad
    }

update ¤

update(state: JaxEnergyState, u: Vector) -> JaxEnergyState
Source code in src/liblaf/apple/jax/model/_energy.py
28
29
def update(self, state: JaxEnergyState, u: Vector) -> JaxEnergyState:  # noqa: ARG002
    """Return ``state`` unchanged; this energy keeps no ``u``-dependent state."""
    unchanged = state
    return unchanged

update_materials ¤

update_materials(materials: EnergyMaterials) -> None
Source code in src/liblaf/apple/jax/model/_energy.py
31
32
def update_materials(self, materials: EnergyMaterials) -> None:
    """Accept new material parameters; intentionally does nothing here."""

value_and_grad ¤

value_and_grad(
    state: JaxEnergyState, u: Vector
) -> tuple[Scalar, Updates]
Source code in src/liblaf/apple/jax/model/_energy.py
58
59
60
61
62
63
64
65
@jarp.jit(inline=True)
def value_and_grad(
    self, state: JaxEnergyState, u: Vector
) -> tuple[Scalar, Updates]:
    """Return ``fun(state, u)`` and its autodiff gradient as a dense update.

    NOTE(review): this differentiates ``fun`` directly and does not use the
    analytic ``grad`` override.
    """
    energy_of_u = jarp.partial(self.fun, state)
    energy, dense_grad = jax.value_and_grad(energy_of_u)(u)
    rows = jnp.arange(u.shape[0])
    return energy, (dense_grad, rows)

JaxMassSpringPrestrain ¤

Bases: JaxMassSpring


              flowchart TD
              liblaf.apple.jax.JaxMassSpringPrestrain[JaxMassSpringPrestrain]
              liblaf.apple.jax.energies._mass_spring.JaxMassSpring[JaxMassSpring]
              liblaf.apple.jax.model._energy.JaxEnergy[JaxEnergy]

                              liblaf.apple.jax.energies._mass_spring.JaxMassSpring --> liblaf.apple.jax.JaxMassSpringPrestrain
                                liblaf.apple.jax.model._energy.JaxEnergy --> liblaf.apple.jax.energies._mass_spring.JaxMassSpring
                



              click liblaf.apple.jax.JaxMassSpringPrestrain href "" "liblaf.apple.jax.JaxMassSpringPrestrain"
              click liblaf.apple.jax.energies._mass_spring.JaxMassSpring href "" "liblaf.apple.jax.energies._mass_spring.JaxMassSpring"
              click liblaf.apple.jax.model._energy.JaxEnergy href "" "liblaf.apple.jax.model._energy.JaxEnergy"
            

Parameters:

  • name ¤

    (str, default: <dynamic> ) –
  • requires_grad ¤

    (frozenset[str], default: frozenset() ) –
  • edges ¤

    (Integer[Array, 'edges 2']) –
  • length ¤

    (Float[Array, edges]) –
  • points ¤

    (Float[Array, 'edges 2 3']) –
  • stiffness ¤

    (Float[Array, edges]) –

Methods:

  • from_pyvista
  • fun
  • grad
  • grad_and_hess_diag
  • hess_diag
  • hess_prod
  • hess_quad
  • init_state
  • mixed_derivative_prod
  • update
  • update_materials
  • value_and_grad

Attributes:

  • edges (Integer[Array, ' edges 2']) –
  • length (Float[Array, ' edges']) –
  • n_edges (int) –
  • name (str) –
  • points (Float[Array, 'edges 2 3']) –
  • requires_grad (frozenset[str]) –
  • stiffness (Float[Array, ' edges']) –

edges instance-attribute ¤

edges: Integer[Array, ' edges 2']

length instance-attribute ¤

length: Float[Array, ' edges']

n_edges property ¤

n_edges: int

name class-attribute instance-attribute ¤

name: str = static(default=name_factory, kw_only=True)

points instance-attribute ¤

points: Float[Array, 'edges 2 3']

requires_grad class-attribute instance-attribute ¤

requires_grad: frozenset[str] = static(
    default=frozenset(), kw_only=True
)

stiffness instance-attribute ¤

stiffness: Float[Array, ' edges']

from_pyvista classmethod ¤

from_pyvista(obj: PolyData) -> Self
Source code in src/liblaf/apple/jax/energies/_mass_spring_prestrain.py
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
@classmethod
def from_pyvista(cls, obj: pv.PolyData) -> Self:
    """Build prestrained springs from the line cells of ``obj``.

    Works like the plain mass-spring constructor, except that each edge's
    rest length is ``LENGTH * (1 + PRESTRAIN)``. Both arrays are read from
    cell data. If ``LENGTH`` is absent, ``compute_cell_sizes(length=True)``
    is run first.

    A warning is logged if either the unscaled or the prestrained lengths
    contain negative values.
    """
    if LENGTH not in obj.cell_data:
        obj = obj.compute_cell_sizes(length=True, area=False, volume=False)  # pyright: ignore[reportAssignmentType]
    point_id: Integer[np.ndarray, " points"] = obj.point_data[GLOBAL_POINT_ID]
    # Assumes every line cell is a two-point segment stored as ``[2, i, j]``.
    edges: Integer[np.ndarray, "edges 2"] = obj.lines.reshape((-1, 3))[:, 1:]
    length: Float[Array, " edges"] = jnp.asarray(obj.cell_data[LENGTH])
    if jnp.any(length < 0.0):
        logger.warning("Length < 0")
    prestrain: Float[Array, " edges"] = jnp.asarray(obj.cell_data[PRESTRAIN])
    rest_length: Float[Array, " edges"] = length * (1.0 + prestrain)
    # A prestrain below -1 flips the sign of the rest length. The check above
    # only sees the unscaled lengths, so validate the final value as well.
    if jnp.any(rest_length < 0.0):
        logger.warning("Prestrained length < 0")
    return cls(
        edges=jnp.asarray(point_id[edges]),
        length=rest_length,
        points=jnp.asarray(obj.points[edges]),
        stiffness=jnp.asarray(obj.cell_data[STIFFNESS]),
    )

fun ¤

fun(state: JaxEnergyState, u: Vector) -> Scalar
Source code in src/liblaf/apple/jax/energies/_mass_spring.py
49
50
51
52
53
54
55
56
57
58
@override
def fun(self, state: JaxEnergyState, u: Vector) -> Scalar:
    """Total spring energy ``sum_e 0.5 * k_e * (|x1 - x0| - L_e)^2``."""
    deformed = self.points + u[self.edges]
    edge_vec = deformed[:, 1, :] - deformed[:, 0, :]
    stretch = jnp.linalg.norm(edge_vec, axis=-1) - self.length
    per_edge = 0.5 * self.stiffness * jnp.square(stretch)
    return jnp.sum(per_edge)

grad ¤

grad(state: JaxEnergyState, u: Vector) -> Updates
Source code in src/liblaf/apple/jax/energies/_mass_spring.py
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
@override
def grad(self, state: JaxEnergyState, u: Vector) -> Updates:
    """Analytic spring gradient as a per-endpoint ``(values, index)`` update.

    Endpoint 0 receives ``-f`` and endpoint 1 receives ``+f``, where
    ``f = k * (l - L) * d / l``. Rows line up with ``self.edges.flatten()``.
    """
    deformed = self.points + u[self.edges]
    edge_vec = deformed[:, 1, :] - deformed[:, 0, :]
    cur_len = jnp.linalg.norm(edge_vec, axis=-1)
    # Divide by 1 for zero-length springs so the direction becomes 0, not NaN.
    safe_len = jnp.where(cur_len > 0, cur_len, 1.0)
    unit = edge_vec / safe_len[:, jnp.newaxis]
    force = (
        self.stiffness[:, jnp.newaxis]
        * (cur_len - self.length)[:, jnp.newaxis]
        * unit
    )
    per_endpoint = jnp.stack([-force, force], axis=1)
    return per_endpoint.reshape(-1, 3), self.edges.flatten()

grad_and_hess_diag ¤

grad_and_hess_diag(
    state: JaxEnergyState, u: Vector
) -> tuple[Updates, Updates]
Source code in src/liblaf/apple/jax/model/_energy.py
67
68
69
70
71
@jarp.jit(inline=True)
def grad_and_hess_diag(
    self, state: JaxEnergyState, u: Vector
) -> tuple[Updates, Updates]:
    """Return ``(self.grad(state, u), self.hess_diag(state, u))``."""
    grad_updates = self.grad(state, u)
    diag_updates = self.hess_diag(state, u)
    return grad_updates, diag_updates

hess_diag ¤

hess_diag(state: JaxEnergyState, u: Vector) -> Updates
Source code in src/liblaf/apple/jax/energies/_mass_spring.py
76
77
78
79
80
81
@override
def hess_diag(self, state: JaxEnergyState, u: Vector) -> Updates:
    """Approximate Hessian diagonal: ``k_e`` for every endpoint component.

    NOTE(review): this is not the exact diagonal of ``fun``'s Hessian, which
    is ``k * ((1 - L/l) + (L/l) * n_j**2)``. It appears to be a deliberate
    positive approximation; confirm.
    """
    diag: Float[Array, "edges*2 3"] = einops.repeat(
        self.stiffness, "edges -> (edges i) j", i=2, j=3
    )
    endpoint_index = self.edges.flatten()
    return diag, endpoint_index

hess_prod ¤

hess_prod(
    state: JaxEnergyState, u: Vector, p: Vector
) -> Updates
Source code in src/liblaf/apple/jax/model/_energy.py
45
46
47
48
49
@jarp.jit(inline=True)
def hess_prod(self, state: JaxEnergyState, u: Vector, p: Vector) -> Updates:
    """Hessian-vector product via the JVP of ``jax.grad(fun)`` (dense update)."""
    energy_of_u = jarp.partial(self.fun, state)
    _primal, tangent = jax.jvp(jax.grad(energy_of_u), (u,), (p,))
    return tangent, jnp.arange(u.shape[0])

hess_quad ¤

hess_quad(
    state: JaxEnergyState, u: Vector, p: Vector
) -> Scalar
Source code in src/liblaf/apple/jax/model/_energy.py
51
52
53
54
55
56
@jarp.jit(inline=True)
def hess_quad(self, state: JaxEnergyState, u: Vector, p: Vector) -> Scalar:
    """Quadratic form ``p^T H(u) p`` built from ``self.hess_prod``.

    ``p`` is gathered at the product's index so that sparse products
    (possibly with repeated indices) contract correctly.
    """
    prod_values, prod_index = self.hess_prod(state, u, p)
    return jnp.vdot(p[prod_index], prod_values)

init_state ¤

init_state(u: Vector) -> JaxEnergyState
Source code in src/liblaf/apple/jax/model/_energy.py
25
26
def init_state(self, u: Vector) -> JaxEnergyState:  # noqa: ARG002
    """Create a fresh, empty per-energy state; ``u`` is not used."""
    fresh_state = JaxEnergyState()
    return fresh_state

mixed_derivative_prod ¤

mixed_derivative_prod(
    state: JaxEnergyState, u: Vector, p: Vector
) -> dict[str, Array]
Source code in src/liblaf/apple/jax/model/_energy.py
73
74
75
76
77
78
79
80
@jarp.jit(inline=True)
def mixed_derivative_prod(
    self, state: JaxEnergyState, u: Vector, p: Vector
) -> dict[str, Array]:
    """Collect mixed-derivative products for every name in ``requires_grad``.

    Each entry dispatches to the method ``mixed_derivative_prod_<name>``
    on this energy, called with ``(state, u, p)``.
    """
    return {
        name: getattr(self, f"mixed_derivative_prod_{name}")(state, u, p)
        for name in self.requires_grad
    }

update ¤

update(state: JaxEnergyState, u: Vector) -> JaxEnergyState
Source code in src/liblaf/apple/jax/model/_energy.py
28
29
def update(self, state: JaxEnergyState, u: Vector) -> JaxEnergyState:  # noqa: ARG002
    """Return ``state`` unchanged; the base energy ignores the new ``u``."""
    unchanged = state
    return unchanged

update_materials ¤

update_materials(materials: EnergyMaterials) -> None
Source code in src/liblaf/apple/jax/model/_energy.py
31
32
def update_materials(self, materials: EnergyMaterials) -> None:
    """No-op hook; subclasses override this to absorb new material values."""

value_and_grad ¤

value_and_grad(
    state: JaxEnergyState, u: Vector
) -> tuple[Scalar, Updates]
Source code in src/liblaf/apple/jax/model/_energy.py
58
59
60
61
62
63
64
65
@jarp.jit(inline=True)
def value_and_grad(
    self, state: JaxEnergyState, u: Vector
) -> tuple[Scalar, Updates]:
    """Energy value and its dense gradient via reverse-mode autodiff.

    Returns:
        ``(value, (grad, index))`` with ``index = arange(u.shape[0])``.
    """
    value_grad_fn = jax.value_and_grad(jarp.partial(self.fun, state))
    value, dense_grad = value_grad_fn(u)
    return value, (dense_grad, jnp.arange(u.shape[0]))

JaxModel ¤

Parameters:

  • energies ¤

    (dict[str, JaxEnergy], default: <class 'dict'> ) –

    Mapping from name to energy term that makes up the model; defaults to an empty dict.

Methods:

  • fun
  • grad
  • grad_and_hess_diag
  • hess_diag
  • hess_prod
  • hess_quad
  • init_state
  • mixed_derivative_prod
  • update
  • update_materials
  • value_and_grad

Attributes:

  • energies (dict[str, JaxEnergy]) –

energies class-attribute instance-attribute ¤

energies: dict[str, JaxEnergy] = field(factory=dict)

fun ¤

fun(state: JaxModelState, u: Vector) -> Scalar
Source code in src/liblaf/apple/jax/model/_model.py
38
39
40
41
42
43
@jarp.jit(inline=True)
def fun(self, state: JaxModelState, u: Vector) -> Scalar:
    """Total energy: the sum of every energy term's ``fun`` at ``u``.

    Each term is evaluated with its own state, looked up by ``energy.name``.
    """
    return sum(
        (
            energy.fun(state.data[energy.name], u)
            for energy in self.energies.values()
        ),
        start=jnp.zeros(()),
    )

grad ¤

grad(state: JaxModelState, u: Vector) -> Vector
Source code in src/liblaf/apple/jax/model/_model.py
45
46
47
48
49
50
51
52
53
@jarp.jit(inline=True)
def grad(self, state: JaxModelState, u: Vector) -> Vector:
    """Dense gradient of the total energy.

    Each term returns ``(values, index)``; these are scatter-added into a
    zero vector shaped like ``u`` (repeated indices accumulate).
    """
    term_updates = [
        energy.grad(state[energy.name], u) for energy in self.energies.values()
    ]
    dense = jnp.zeros_like(u)
    for values, index in term_updates:
        dense = dense.at[index].add(values)
    return dense

grad_and_hess_diag ¤

grad_and_hess_diag(
    state: JaxModelState, u: Vector
) -> tuple[Vector, Vector]
Source code in src/liblaf/apple/jax/model/_model.py
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
@jarp.jit(inline=True)
def grad_and_hess_diag(
    self, state: JaxModelState, u: Vector
) -> tuple[Vector, Vector]:
    """Dense gradient and dense Hessian diagonal of the total energy.

    Both are accumulated from each term's ``grad_and_hess_diag`` updates by
    scatter-adding into zero vectors shaped like ``u``.
    """
    grad_acc = jnp.zeros_like(u)
    diag_acc = jnp.zeros_like(u)
    for energy in self.energies.values():
        (g_values, g_index), (d_values, d_index) = energy.grad_and_hess_diag(
            state[energy.name], u
        )
        grad_acc = grad_acc.at[g_index].add(g_values)
        diag_acc = diag_acc.at[d_index].add(d_values)
    return grad_acc, diag_acc

hess_diag ¤

hess_diag(state: JaxModelState, u: Vector) -> Vector
Source code in src/liblaf/apple/jax/model/_model.py
55
56
57
58
59
60
61
62
63
@jarp.jit(inline=True)
def hess_diag(self, state: JaxModelState, u: Vector) -> Vector:
    """Dense Hessian diagonal of the total energy.

    The ``(values, index)`` update of every term is scatter-added into a
    zero vector shaped like ``u``.
    """
    term_updates = [
        energy.hess_diag(state[energy.name], u)
        for energy in self.energies.values()
    ]
    dense = jnp.zeros_like(u)
    for values, index in term_updates:
        dense = dense.at[index].add(values)
    return dense

hess_prod ¤

hess_prod(
    state: JaxModelState, u: Vector, p: Vector
) -> Vector
Source code in src/liblaf/apple/jax/model/_model.py
65
66
67
68
69
70
71
72
73
@jarp.jit(inline=True)
def hess_prod(self, state: JaxModelState, u: Vector, p: Vector) -> Vector:
    """Dense Hessian-vector product ``H(u) @ p`` of the total energy.

    Each term's ``(values, index)`` product is scatter-added into a zero
    vector shaped like ``u``.
    """
    term_updates = [
        energy.hess_prod(state[energy.name], u, p)
        for energy in self.energies.values()
    ]
    dense = jnp.zeros_like(u)
    for values, index in term_updates:
        dense = dense.at[index].add(values)
    return dense

hess_quad ¤

hess_quad(
    state: JaxModelState, u: Vector, p: Vector
) -> Scalar
Source code in src/liblaf/apple/jax/model/_model.py
75
76
77
78
79
80
@jarp.jit(inline=True)
def hess_quad(self, state: JaxModelState, u: Vector, p: Vector) -> Scalar:
    """Quadratic form ``p^T H(u) p``, summed over all energy terms."""
    return sum(
        (
            energy.hess_quad(state[energy.name], u, p)
            for energy in self.energies.values()
        ),
        start=jnp.zeros(()),
    )

init_state ¤

init_state(u: Vector) -> JaxModelState
Source code in src/liblaf/apple/jax/model/_model.py
22
23
24
25
26
def init_state(self, u: Vector) -> JaxModelState:
    """Build a model state holding ``u`` plus each term's initial state.

    Per-term states are keyed by ``energy.name``.
    """
    per_energy = {
        energy.name: energy.init_state(u) for energy in self.energies.values()
    }
    return JaxModelState(u=u, data=per_energy)

mixed_derivative_prod ¤

mixed_derivative_prod(
    state: JaxModelState, u: Vector, p: Vector
) -> dict[str, dict[str, Array]]
Source code in src/liblaf/apple/jax/model/_model.py
82
83
84
85
86
87
88
89
@jarp.jit(inline=True)
def mixed_derivative_prod(
    self, state: JaxModelState, u: Vector, p: Vector
) -> dict[str, dict[str, Array]]:
    """Mixed-derivative products of every term, keyed by the ``energies`` key.

    Each term is evaluated with the state stored under ``energy.name``.
    """
    outputs: dict[str, dict[str, Array]] = {}
    for key, energy in self.energies.items():
        outputs[key] = energy.mixed_derivative_prod(state[energy.name], u, p)
    return outputs

update ¤

update(state: JaxModelState, u: Vector) -> JaxModelState
Source code in src/liblaf/apple/jax/model/_model.py
28
29
30
31
32
def update(self, state: JaxModelState, u: Vector) -> JaxModelState:
    """Advance every energy's state to the displacement ``u``.

    ``JaxEnergy.update`` returns the (possibly new) per-energy state. That
    return value is stored back into ``state.data`` so energies that build a
    fresh state object, rather than mutating in place, are not discarded.

    Returns:
        ``state`` (mutated in place), with ``state.u`` set to ``u``.
    """
    for energy in self.energies.values():
        # Keep the returned state; ignoring it would drop functional updates.
        state.data[energy.name] = energy.update(state.data[energy.name], u)
    state.u = u
    return state

update_materials ¤

update_materials(materials: ModelMaterials) -> None
Source code in src/liblaf/apple/jax/model/_model.py
34
35
36
def update_materials(self, materials: ModelMaterials) -> None:
    """Forward each term's materials to the energy registered under that key.

    Raises ``KeyError`` if ``materials`` names an energy that does not exist.
    """
    for key, per_energy in materials.items():
        target = self.energies[key]
        target.update_materials(per_energy)

value_and_grad ¤

value_and_grad(
    state: JaxModelState, u: Vector
) -> tuple[Scalar, Vector]
Source code in src/liblaf/apple/jax/model/_model.py
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
@jarp.jit(inline=True)
def value_and_grad(self, state: JaxModelState, u: Vector) -> tuple[Scalar, Vector]:
    """Total energy and its dense gradient, accumulated over all terms.

    Each term returns ``(value, (grad_values, index))``. Values are summed,
    and gradient updates are scatter-added into a zero vector shaped like ``u``.
    """
    total = jnp.zeros(())
    dense_grad = jnp.zeros_like(u)
    for energy in self.energies.values():
        term_value, (term_grad, term_index) = energy.value_and_grad(
            state[energy.name], u
        )
        total = total + term_value
        dense_grad = dense_grad.at[term_index].add(term_grad)
    return total, dense_grad

JaxModelBuilder ¤

Parameters:

  • energies ¤

    (dict[str, JaxEnergy], default: <class 'dict'> ) –

    Energies registered via `add_energy`, keyed by `energy.name`; defaults to an empty dict.

Methods:

  • add_energy
  • finalize

Attributes:

  • energies (dict[str, JaxEnergy]) –

energies class-attribute instance-attribute ¤

energies: dict[str, JaxEnergy] = field(
    factory=dict, kw_only=True
)

add_energy ¤

add_energy(energy: JaxEnergy) -> None
Source code in src/liblaf/apple/jax/model/_builder.py
11
12
def add_energy(self, energy: JaxEnergy) -> None:
    """Register ``energy`` under its ``name``, replacing any previous entry."""
    key = energy.name
    self.energies[key] = energy

finalize ¤

finalize() -> JaxModel
Source code in src/liblaf/apple/jax/model/_builder.py
14
15
def finalize(self) -> JaxModel:
    """Build a ``JaxModel`` that shares this builder's ``energies`` dict."""
    model = JaxModel(energies=self.energies)
    return model

JaxModelState ¤

Bases: MutableMapping[str, JaxEnergyState]


              flowchart TD
              liblaf.apple.jax.JaxModelState[JaxModelState]

              

              click liblaf.apple.jax.JaxModelState href "" "liblaf.apple.jax.JaxModelState"
            

Parameters:

  • u ¤

    (Vector) –
  • data ¤

    (dict[str, JaxEnergyState], default: <class 'dict'> ) –

    Per-energy states, keyed by energy name; defaults to an empty dict.

Methods:

  • __delitem__
  • __getitem__
  • __iter__
  • __len__
  • __setitem__

Attributes:

  • data (dict[str, JaxEnergyState]) –
  • u (Vector) –

data class-attribute instance-attribute ¤

data: dict[str, JaxEnergyState] = field(factory=dict)

u instance-attribute ¤

u: Vector

__delitem__ ¤

__delitem__(key: str) -> None
Source code in src/liblaf/apple/jax/model/_state.py
25
26
def __delitem__(self, key: str) -> None:
    """Remove the state stored under ``key``; raises ``KeyError`` if absent."""
    self.data.pop(key)

__getitem__ ¤

__getitem__(key: str) -> JaxEnergyState
Source code in src/liblaf/apple/jax/model/_state.py
19
20
def __getitem__(self, key: str) -> JaxEnergyState:
    """Return the state stored under ``key``; raises ``KeyError`` if absent."""
    entry = self.data[key]
    return entry

__iter__ ¤

__iter__() -> Iterator[str]
Source code in src/liblaf/apple/jax/model/_state.py
28
29
def __iter__(self) -> Iterator[str]:
    """Iterate over the energy names that have a stored state."""
    return iter(self.data.keys())

__len__ ¤

__len__() -> int
Source code in src/liblaf/apple/jax/model/_state.py
31
32
def __len__(self) -> int:
    """Return the number of stored per-energy states."""
    count = len(self.data)
    return count

__setitem__ ¤

__setitem__(key: str, value: JaxEnergyState) -> None
Source code in src/liblaf/apple/jax/model/_state.py
22
23
def __setitem__(self, key: str, value: JaxEnergyState) -> None:
    """Store ``value`` as the state for ``key``, overwriting any existing one."""
    self.data.update({key: value})

JaxPointForce ¤

Bases: JaxEnergy


              flowchart TD
              liblaf.apple.jax.JaxPointForce[JaxPointForce]
              liblaf.apple.jax.model._energy.JaxEnergy[JaxEnergy]

                              liblaf.apple.jax.model._energy.JaxEnergy --> liblaf.apple.jax.JaxPointForce
                


              click liblaf.apple.jax.JaxPointForce href "" "liblaf.apple.jax.JaxPointForce"
              click liblaf.apple.jax.model._energy.JaxEnergy href "" "liblaf.apple.jax.model._energy.JaxEnergy"
            

Parameters:

  • name ¤

    (str, default: <dynamic> ) –
  • requires_grad ¤

    (frozenset[str], default: frozenset() ) –
  • force ¤

    (Vector) –
  • indices ¤

    (Index) –

Methods:

  • from_pyvista
  • fun
  • grad
  • grad_and_hess_diag
  • hess_diag
  • hess_prod
  • hess_quad
  • init_state
  • mixed_derivative_prod
  • update
  • update_materials
  • value_and_grad

Attributes:

  • force (Vector) –
  • indices (Index) –
  • name (str) –
  • requires_grad (frozenset[str]) –

force class-attribute instance-attribute ¤

force: Vector = field()

indices class-attribute instance-attribute ¤

indices: Index = field()

name class-attribute instance-attribute ¤

name: str = static(default=name_factory, kw_only=True)

requires_grad class-attribute instance-attribute ¤

requires_grad: frozenset[str] = static(
    default=frozenset(), kw_only=True
)

from_pyvista classmethod ¤

from_pyvista(obj: DataSet) -> Self
Source code in src/liblaf/apple/jax/energies/_force.py
22
23
24
25
26
27
@classmethod
def from_pyvista(cls, obj: pv.DataSet) -> Self:
    """Build point loads from a PyVista dataset's point data.

    Reads the ``"Force"`` array as the force values and the
    ``GLOBAL_POINT_ID`` array as the indices they apply to.
    """
    point_data = obj.point_data
    force_values = jnp.asarray(point_data["Force"])
    point_ids = jnp.asarray(point_data[GLOBAL_POINT_ID])
    return cls(force=force_values, indices=point_ids)

fun ¤

fun(state: JaxEnergyState, u: Vector) -> Scalar
Source code in src/liblaf/apple/jax/energies/_force.py
29
30
31
32
@override
@jarp.jit(inline=True)
def fun(self, state: JaxEnergyState, u: Vector) -> Scalar:
    """Energy of constant point loads: ``-force . u[indices]``.

    The result is linear in ``u``.
    """
    loaded_dofs = u[self.indices]
    work = jnp.vdot(self.force, loaded_dofs)
    return -work

grad ¤

grad(state: JaxEnergyState, u: Vector) -> Updates
Source code in src/liblaf/apple/jax/model/_energy.py
37
38
39
40
@jarp.jit(inline=True)
def grad(self, state: JaxEnergyState, u: Vector) -> Updates:
    """Dense gradient of ``self.fun`` with respect to ``u`` via autodiff.

    Returns:
        ``(values, index)`` with ``index = arange(u.shape[0])``.
    """
    energy_fn = jarp.partial(self.fun, state)
    dense_grad = eqx.filter_grad(energy_fn)(u)
    return dense_grad, jnp.arange(u.shape[0])

grad_and_hess_diag ¤

grad_and_hess_diag(
    state: JaxEnergyState, u: Vector
) -> tuple[Updates, Updates]
Source code in src/liblaf/apple/jax/model/_energy.py
67
68
69
70
71
@jarp.jit(inline=True)
def grad_and_hess_diag(
    self, state: JaxEnergyState, u: Vector
) -> tuple[Updates, Updates]:
    """Return the gradient and Hessian-diagonal updates together.

    Each element is the ``(values, index)`` pair produced by ``self.grad``
    and ``self.hess_diag`` respectively, evaluated in that order.
    """
    gradient_updates = self.grad(state, u)
    diagonal_updates = self.hess_diag(state, u)
    return gradient_updates, diagonal_updates

hess_diag ¤

hess_diag(state: JaxEnergyState, u: Vector) -> Updates
Source code in src/liblaf/apple/jax/energies/_force.py
34
35
36
37
@override
@jarp.jit(inline=True)
def hess_diag(self, state: JaxEnergyState, u: Vector) -> Updates:
    """Zero Hessian diagonal at the loaded points (``fun`` is linear in ``u``).

    Returns:
        ``(zeros, indices)`` with zeros shaped like ``u[self.indices]``.
    """
    loaded_dofs = u[self.indices]
    return jnp.zeros_like(loaded_dofs), self.indices

hess_prod ¤

hess_prod(
    state: JaxEnergyState, u: Vector, p: Vector
) -> Updates
Source code in src/liblaf/apple/jax/model/_energy.py
45
46
47
48
49
@jarp.jit(inline=True)
def hess_prod(self, state: JaxEnergyState, u: Vector, p: Vector) -> Updates:
    """Hessian-vector product ``H(u) @ p`` via forward-over-reverse autodiff.

    Returns:
        ``(values, index)`` where ``values`` is dense over ``u`` and
        ``index`` is ``arange(u.shape[0])``.
    """
    grad_fn = jax.grad(jarp.partial(self.fun, state))
    _, hvp = jax.jvp(grad_fn, (u,), (p,))
    dense_index = jnp.arange(u.shape[0])
    return hvp, dense_index

hess_quad ¤

hess_quad(
    state: JaxEnergyState, u: Vector, p: Vector
) -> Scalar
Source code in src/liblaf/apple/jax/model/_energy.py
51
52
53
54
55
56
@jarp.jit(inline=True)
def hess_quad(self, state: JaxEnergyState, u: Vector, p: Vector) -> Scalar:
    """Quadratic form ``p^T H(u) p`` built from ``self.hess_prod``.

    ``p`` is gathered at the product's index so that sparse products
    (possibly with repeated indices) contract correctly.
    """
    prod_values, prod_index = self.hess_prod(state, u, p)
    return jnp.vdot(p[prod_index], prod_values)

init_state ¤

init_state(u: Vector) -> JaxEnergyState
Source code in src/liblaf/apple/jax/model/_energy.py
25
26
def init_state(self, u: Vector) -> JaxEnergyState:  # noqa: ARG002
    """Create a fresh, empty per-energy state; ``u`` is not used."""
    fresh_state = JaxEnergyState()
    return fresh_state

mixed_derivative_prod ¤

mixed_derivative_prod(
    state: JaxEnergyState, u: Vector, p: Vector
) -> dict[str, Array]
Source code in src/liblaf/apple/jax/model/_energy.py
73
74
75
76
77
78
79
80
@jarp.jit(inline=True)
def mixed_derivative_prod(
    self, state: JaxEnergyState, u: Vector, p: Vector
) -> dict[str, Array]:
    """Collect mixed-derivative products for every name in ``requires_grad``.

    Each entry dispatches to the method ``mixed_derivative_prod_<name>``
    on this energy, called with ``(state, u, p)``.
    """
    return {
        name: getattr(self, f"mixed_derivative_prod_{name}")(state, u, p)
        for name in self.requires_grad
    }

update ¤

update(state: JaxEnergyState, u: Vector) -> JaxEnergyState
Source code in src/liblaf/apple/jax/model/_energy.py
28
29
def update(self, state: JaxEnergyState, u: Vector) -> JaxEnergyState:  # noqa: ARG002
    """Return ``state`` unchanged; the base energy ignores the new ``u``."""
    unchanged = state
    return unchanged

update_materials ¤

update_materials(materials: EnergyMaterials) -> None
Source code in src/liblaf/apple/jax/model/_energy.py
31
32
def update_materials(self, materials: EnergyMaterials) -> None:
    """No-op hook; subclasses override this to absorb new material values."""

value_and_grad ¤

value_and_grad(
    state: JaxEnergyState, u: Vector
) -> tuple[Scalar, Updates]
Source code in src/liblaf/apple/jax/model/_energy.py
58
59
60
61
62
63
64
65
@jarp.jit(inline=True)
def value_and_grad(
    self, state: JaxEnergyState, u: Vector
) -> tuple[Scalar, Updates]:
    """Energy value and its dense gradient via reverse-mode autodiff.

    Returns:
        ``(value, (grad, index))`` with ``index = arange(u.shape[0])``.
    """
    value_grad_fn = jax.value_and_grad(jarp.partial(self.fun, state))
    value, dense_grad = value_grad_fn(u)
    return value, (dense_grad, jnp.arange(u.shape[0]))

QuadratureTetra ¤

Bases: Scheme


              flowchart TD
              liblaf.apple.jax.QuadratureTetra[QuadratureTetra]
              liblaf.apple.jax.fem.quadrature._scheme.Scheme[Scheme]

                              liblaf.apple.jax.fem.quadrature._scheme.Scheme --> liblaf.apple.jax.QuadratureTetra
                


              click liblaf.apple.jax.QuadratureTetra href "" "liblaf.apple.jax.QuadratureTetra"
              click liblaf.apple.jax.fem.quadrature._scheme.Scheme href "" "liblaf.apple.jax.fem.quadrature._scheme.Scheme"
            

Parameters:

  • points ¤

    (Float[Array, 'quadrature dim'], default: Array([[0.25, 0.25, 0.25]], dtype=float32) ) –
  • weights ¤

    (Float[Array, quadrature], default: Array([0.16666667], dtype=float32) ) –

Methods:

  • from_felupe
  • from_order

Attributes:

  • dim (int) –
  • n_points (int) –
  • points (Float[Array, 'quadrature dim']) –
  • weights (Float[Array, ' quadrature']) –

dim property ¤

dim: int

n_points property ¤

n_points: int

points class-attribute instance-attribute ¤

points: Float[Array, "quadrature dim"] = array(
    factory=lambda: ones((1, 3)) / 4.0
)

weights class-attribute instance-attribute ¤

weights: Float[Array, " quadrature"] = array(
    factory=lambda: ones((1,)) / 6.0
)

from_felupe classmethod ¤

from_felupe(schema: Scheme) -> Self
Source code in src/liblaf/apple/jax/fem/quadrature/_scheme.py
14
15
16
17
18
@classmethod
def from_felupe(cls, schema: felupe.quadrature.Scheme) -> Self:
    """Convert a felupe quadrature scheme into JAX arrays.

    Copies ``schema.points`` and ``schema.weights`` via ``jnp.asarray``.
    """
    points = jnp.asarray(schema.points)
    weights = jnp.asarray(schema.weights)
    return cls(points=points, weights=weights)

from_order classmethod ¤

from_order(order: int = 1) -> Self
Source code in src/liblaf/apple/jax/fem/quadrature/_tetra.py
20
21
22
@classmethod
def from_order(cls, order: int = 1) -> Self:
    """Build a tetrahedral quadrature scheme of the given ``order`` via felupe."""
    felupe_scheme = felupe.quadrature.Tetrahedron(order=order)
    return cls.from_felupe(felupe_scheme)

Region ¤

Parameters:

  • geometry ¤

    (Geometry) –
  • quadrature ¤

    (Scheme) –
  • h ¤

    (Float[Array, 'q a'], default: None ) –
  • dhdr ¤

    (Float[Array, 'q a J'], default: None ) –
  • dXdr ¤

    (Float[Array, 'c q J J'], default: None ) –
  • drdX ¤

    (Float[Array, 'c q J J'], default: None ) –
  • dV ¤

    (Float[Array, 'c q'], default: None ) –
  • dhdX ¤

    (Float[Array, 'c q a J'], default: None ) –

Methods:

  • compute_grad
  • deformation_gradient
  • from_geometry
  • from_pyvista
  • gradient
  • integrate
  • scatter

Attributes:

  • cell_data (DataSetAttributes) –
  • cells_global (Integer[Array, 'c a']) –
  • cells_local (Integer[Array, 'c a']) –
  • dV (Float[Array, 'c q']) –
  • dXdr (Float[Array, 'c q J J']) –
  • dhdX (Float[Array, 'c q a J']) –
  • dhdr (Float[Array, 'q a J']) –
  • drdX (Float[Array, 'c q J J']) –
  • element (Element) –
  • geometry (Geometry) –
  • h (Float[Array, 'q a']) –
  • mesh (DataSet) –
  • n_cells (int) –
  • point_data (DataSetAttributes) –
  • points (Float[Array, 'p J']) –
  • quadrature (Scheme) –

cell_data property ¤

cell_data: DataSetAttributes

cells_global property ¤

cells_global: Integer[Array, 'c a']

cells_local property ¤

cells_local: Integer[Array, 'c a']

dV class-attribute instance-attribute ¤

dV: Float[Array, 'c q'] = array(default=None)

dXdr class-attribute instance-attribute ¤

dXdr: Float[Array, 'c q J J'] = array(default=None)

dhdX class-attribute instance-attribute ¤

dhdX: Float[Array, 'c q a J'] = array(default=None)

dhdr class-attribute instance-attribute ¤

dhdr: Float[Array, 'q a J'] = array(default=None)

drdX class-attribute instance-attribute ¤

drdX: Float[Array, 'c q J J'] = array(default=None)

element property ¤

element: Element

geometry class-attribute instance-attribute ¤

geometry: Geometry = field()

h class-attribute instance-attribute ¤

h: Float[Array, 'q a'] = array(default=None)

mesh property ¤

mesh: DataSet

n_cells property ¤

n_cells: int

point_data property ¤

point_data: DataSetAttributes

points property ¤

points: Float[Array, 'p J']

quadrature class-attribute instance-attribute ¤

quadrature: Scheme = field()

compute_grad ¤

compute_grad() -> None
Source code in src/liblaf/apple/jax/fem/region/_region.py
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
def compute_grad(self) -> None:
    """Precompute shape functions, Jacobians and volume weights.

    Evaluates the element's shape functions and their reference gradients
    at every quadrature point. It then forms the per-cell Jacobian
    ``dXdr``, its inverse ``drdX``, the weighted volume element
    ``dV = det(dXdr) * weight`` and the physical shape-function gradients
    ``dhdX``. All are stored on ``self``.

    A warning (not an exception) is logged if any ``dV`` is non-positive,
    which typically indicates an inverted or degenerate cell.
    """
    quad_points = self.quadrature.points
    shape_values = jnp.stack([self.element.function(xi) for xi in quad_points])  # (q, a)
    shape_grads_ref = jnp.stack(
        [self.element.gradient(xi) for xi in quad_points]
    )  # (q, a, J)
    cell_coords = self.points[self.cells_local]  # (c, a, I)
    jacobian = einops.einsum(
        cell_coords, shape_grads_ref, "c a I, q a J -> c q I J"
    )
    inv_jacobian = jnp.linalg.inv(jacobian)
    volume = jnp.linalg.det(jacobian) * self.quadrature.weights[jnp.newaxis, :]
    if jnp.any(volume <= 0):
        logger.warning("dV <= 0")
    shape_grads_phys = einops.einsum(
        shape_grads_ref, inv_jacobian, "q a I, c q I J -> c q a J"
    )
    self.h = shape_values
    self.dhdr = shape_grads_ref
    self.dXdr = jacobian
    self.drdX = inv_jacobian
    self.dV = volume
    self.dhdX = shape_grads_phys

deformation_gradient ¤

deformation_gradient(
    u: Float[Array, "p J"],
) -> Float[Array, "c q J J"]
Source code in src/liblaf/apple/jax/fem/region/_region.py
110
111
112
113
114
115
def deformation_gradient(self, u: Float[Array, "p J"]) -> Float[Array, "c q J J"]:
    """Return the deformation gradient ``F = I + grad(u)`` at every quadrature point.

    Args:
        u: Nodal displacement field, one row per point.

    Returns:
        ``F`` with shape ``(c, q, J, J)``. The identity is sized from the
        last axis of the displacement gradient rather than hard-coded to 3,
        so regions of any spatial dimension are supported.
    """
    grad: Float[Array, "c q J J"] = self.gradient(u)
    # Sizing the identity from the gradient (instead of a fixed 3) avoids a
    # broadcasting error for non-3-D regions; matching the dtype avoids
    # silent promotion of ``F``.
    eye = jnp.identity(grad.shape[-1], dtype=grad.dtype)
    # Broadcasting adds the identity to every (cell, quadrature) block.
    F: Float[Array, "c q J J"] = grad + eye
    return F

from_geometry classmethod ¤

from_geometry(
    geometry: Geometry,
    *,
    grad: bool = False,
    quadrature: Scheme | None = None,
) -> Self
Source code in src/liblaf/apple/jax/fem/region/_region.py
29
30
31
32
33
34
35
36
37
38
@classmethod
def from_geometry(
    cls, geometry: Geometry, *, grad: bool = False, quadrature: Scheme | None = None
) -> Self:
    """Create a region over ``geometry``.

    Args:
        geometry: The mesh geometry the region integrates over.
        grad: If true, call ``compute_grad`` on the new region immediately.
        quadrature: Quadrature scheme; defaults to
            ``geometry.element.quadrature`` when ``None``.
    """
    chosen = geometry.element.quadrature if quadrature is None else quadrature
    region = cls(geometry=geometry, quadrature=chosen)
    if grad:
        region.compute_grad()
    return region

from_pyvista classmethod ¤

from_pyvista(
    mesh: DataObject,
    *,
    grad: bool = False,
    quadrature: Scheme | None = None,
) -> Self
Source code in src/liblaf/apple/jax/fem/region/_region.py
40
41
42
43
44
45
46
47
48
49
50
@classmethod
def from_pyvista(
    cls,
    mesh: pv.DataObject,
    *,
    grad: bool = False,
    quadrature: Scheme | None = None,
) -> Self:
    """Create a region from a PyVista mesh via ``Geometry.from_pyvista``.

    ``grad`` and ``quadrature`` are forwarded unchanged to ``from_geometry``.
    """
    return cls.from_geometry(
        Geometry.from_pyvista(mesh), grad=grad, quadrature=quadrature
    )

gradient ¤

gradient(
    u: Float[Array, " points *shape"],
) -> Float[Array, "c q *shape J"]
Source code in src/liblaf/apple/jax/fem/region/_region.py
117
118
119
120
121
122
123
def gradient(
    self, u: Float[Array, " points *shape"]
) -> Float[Array, "c q *shape J"]:
    """Spatial gradient of a nodal field at every quadrature point.

    The field is gathered per cell and contracted with the physical
    shape-function gradients ``dhdX`` over the cell's nodes.
    """
    per_cell = self.scatter(u)  # (c, a, *shape)
    return einops.einsum(per_cell, self.dhdX, "c a ..., c q a J -> c q ... J")

integrate ¤

integrate(
    a: Float[Array, "c q *shape"],
) -> Float[Array, " c *shape"]
Source code in src/liblaf/apple/jax/fem/region/_region.py
125
126
def integrate(self, a: Float[Array, "c q *shape"]) -> Float[Array, " c *shape"]:
    """Integrate quadrature-point values over each cell, weighting by ``dV``."""
    per_cell = einops.einsum(a, self.dV, "c q ..., c q -> c ...")
    return per_cell

scatter ¤

scatter(
    u: Float[Array, " points *shape"],
) -> Float[Array, "c a *shape"]
Source code in src/liblaf/apple/jax/fem/region/_region.py
128
129
def scatter(self, u: Float[Array, " points *shape"]) -> Float[Array, "c a *shape"]:
    """Gather nodal values per cell using the global cell connectivity."""
    connectivity = self.cells_global
    return u[connectivity]

Scheme ¤

Parameters:

  • points ¤

    (Float[Array, 'quadrature dim']) –
  • weights ¤

    (Float[Array, quadrature]) –

Methods:

  • from_felupe

Attributes:

  • dim (int) –
  • n_points (int) –
  • points (Float[Array, 'quadrature dim']) –
  • weights (Float[Array, ' quadrature']) –

dim property ¤

dim: int

n_points property ¤

n_points: int

points class-attribute instance-attribute ¤

points: Float[Array, 'quadrature dim'] = array()

weights class-attribute instance-attribute ¤

weights: Float[Array, ' quadrature'] = array()

from_felupe classmethod ¤

from_felupe(schema: Scheme) -> Self
Source code in src/liblaf/apple/jax/fem/quadrature/_scheme.py
14
15
16
17
18
@classmethod
def from_felupe(cls, schema: felupe.quadrature.Scheme) -> Self:
    """Convert a felupe quadrature scheme into JAX arrays.

    Copies ``schema.points`` and ``schema.weights`` via ``jnp.asarray``.
    """
    points = jnp.asarray(schema.points)
    weights = jnp.asarray(schema.weights)
    return cls(points=points, weights=weights)