Skip to content

liblaf.apple.jax.model ¤

Modules:

  • dirichlet

Classes:

  • Dirichlet
  • DirichletBuilder
  • JaxEnergy
  • JaxEnergyState
  • JaxModel
  • JaxModelBuilder
  • JaxModelState

Dirichlet ¤

Parameters:

  • dim ¤

    (int) –
  • dirichlet_index ¤

    (Integer[Array, ' dirichlet']) –
  • dirichlet_value ¤

    (Float[Array, ' dirichlet']) –
  • fixed_mask ¤

    (Bool[Array, 'points dim']) –
  • free_index ¤

    (Integer[Array, ' free']) –
  • n_points ¤

    (int) –

Methods:

  • get_fixed
  • get_free
  • set_fixed
  • set_free
  • to_full

Attributes:

  • dim (int) –
  • dirichlet_index (Integer[Array, ' dirichlet']) –
  • dirichlet_value (Float[Array, ' dirichlet']) –
  • fixed_mask (Bool[Array, 'points dim']) –
  • free_index (Integer[Array, ' free']) –
  • n_dirichlet (int) –
  • n_free (int) –
  • n_full (int) –
  • n_points (int) –

dim class-attribute instance-attribute ¤

dim: int = static()

dirichlet_index instance-attribute ¤

dirichlet_index: Integer[Array, ' dirichlet']

dirichlet_value instance-attribute ¤

dirichlet_value: Float[Array, ' dirichlet']

fixed_mask instance-attribute ¤

fixed_mask: Bool[Array, 'points dim']

free_index instance-attribute ¤

free_index: Integer[Array, ' free']

n_dirichlet property ¤

n_dirichlet: int

n_free property ¤

n_free: int

n_full property ¤

n_full: int

n_points class-attribute instance-attribute ¤

n_points: int = static()

get_fixed ¤

get_fixed(
    full: Float[Array, "points dim"],
) -> Float[Array, " dirichlet"]
Source code in src/liblaf/apple/jax/model/dirichlet/_dirichlet.py
27
28
29
@jarp.jit(inline=True)
def get_fixed(self, full: Float[Array, "points dim"]) -> Float[Array, " dirichlet"]:
    """Gather the Dirichlet-constrained entries of ``full``.

    ``full`` is flattened and indexed with ``self.dirichlet_index``, so the result
    follows the order of that index array.
    """
    flat = full.reshape(-1)
    return flat[self.dirichlet_index]

get_free ¤

get_free(
    full: Float[Array, "points dim"],
) -> Float[Array, " free"]
Source code in src/liblaf/apple/jax/model/dirichlet/_dirichlet.py
31
32
33
@jarp.jit(inline=True)
def get_free(self, full: Float[Array, "points dim"]) -> Float[Array, " free"]:
    """Gather the unconstrained (free) entries of ``full``.

    ``full`` is flattened and indexed with ``self.free_index``.
    """
    flat = full.reshape(-1)
    return flat[self.free_index]

set_fixed ¤

set_fixed(
    full: Float[Array, "points dim"],
    values: Float[ArrayLike, " dirichlet"] | None = None,
) -> Float[Array, "points dim"]
Source code in src/liblaf/apple/jax/model/dirichlet/_dirichlet.py
35
36
37
38
39
40
41
42
43
@jarp.jit(inline=True)
def set_fixed(
    self,
    full: Float[Array, "points dim"],
    values: Float[ArrayLike, " dirichlet"] | None = None,
) -> Float[Array, "points dim"]:
    """Return a copy of ``full`` with the Dirichlet entries overwritten.

    Args:
        full: Array whose flattened entries at ``self.dirichlet_index`` are replaced.
        values: Replacement values, one per Dirichlet index. When omitted, the
            stored ``self.dirichlet_value`` is used.

    Returns:
        A new array with the same shape as ``full`` (the input is not modified;
        the update goes through JAX's functional ``.at[...].set``).
    """
    fill = self.dirichlet_value if values is None else values
    flat = full.flatten()
    updated = flat.at[self.dirichlet_index].set(fill)
    return updated.reshape(full.shape)

set_free ¤

set_free(
    full: Float[Array, "points dim"],
    values: Float[ArrayLike, " free"],
) -> Float[Array, "points dim"]
Source code in src/liblaf/apple/jax/model/dirichlet/_dirichlet.py
45
46
47
48
49
@jarp.jit(inline=True)
def set_free(
    self, full: Float[Array, "points dim"], values: Float[ArrayLike, " free"]
) -> Float[Array, "points dim"]:
    """Return a copy of ``full`` with the free entries replaced by ``values``.

    The write targets the flattened positions in ``self.free_index``. The result is
    reshaped back to ``full.shape``.
    """
    flat = full.flatten()
    updated = flat.at[self.free_index].set(values)
    return updated.reshape(full.shape)

to_full ¤

to_full(
    free: Float[Array, " free"],
    dirichlet: Float[ArrayLike, " dirichlet"] | None = None,
) -> Float[Array, "points dim"]
Source code in src/liblaf/apple/jax/model/dirichlet/_dirichlet.py
51
52
53
54
55
56
57
58
59
60
61
62
@jarp.jit(inline=True)
def to_full(
    self,
    free: Float[Array, " free"],
    dirichlet: Float[ArrayLike, " dirichlet"] | None = None,
) -> Float[Array, "points dim"]:
    """Assemble a full ``(n_points, dim)`` array from free and Dirichlet values.

    The free entries are written first, then the Dirichlet entries. The Dirichlet
    entries come from ``dirichlet``, or from the stored ``dirichlet_value`` when it
    is ``None``. The result has ``free.dtype``.

    NOTE(review): positions covered by neither ``free_index`` nor
    ``dirichlet_index`` keep whatever ``jnp.empty`` produced. This assumes both
    index sets together cover the flattened array.
    """
    buffer: Float[Array, "points dim"] = jnp.empty(
        (self.n_points, self.dim), free.dtype
    )
    with_free = self.set_free(buffer, free)
    return self.set_fixed(with_free, dirichlet)

DirichletBuilder ¤

DirichletBuilder(dim: int = 3)

Parameters:

  • mask ¤

    (Bool[ndarray, 'points dim']) –
  • value ¤

    (Float[ndarray, 'points dim']) –

Methods:

  • add_pyvista
  • finalize
  • resize

Attributes:

  • dim (int) –
  • mask (Bool[ndarray, 'points dim']) –
  • n_points (int) –
  • value (Float[ndarray, 'points dim']) –
Source code in src/liblaf/apple/jax/model/dirichlet/_builder.py
19
20
21
22
def __init__(self, dim: int = 3) -> None:
    """Create an empty builder for ``dim`` components per point.

    Both ``mask`` (bool) and ``value`` (default float dtype) start with zero rows
    and ``dim`` columns. Rows are added later by ``resize``.
    """
    empty_shape = (0, dim)
    self.__attrs_init__(  # pyright: ignore[reportAttributeAccessIssue]
        mask=np.empty(empty_shape, dtype=bool),
        value=np.empty(empty_shape),
    )

dim property ¤

dim: int

mask instance-attribute ¤

mask: Bool[ndarray, 'points dim']

n_points property ¤

n_points: int

value instance-attribute ¤

value: Float[ndarray, 'points dim']

add_pyvista ¤

add_pyvista(obj: DataSet) -> None
Source code in src/liblaf/apple/jax/model/dirichlet/_builder.py
32
33
34
35
36
37
38
39
40
41
42
def add_pyvista(self, obj: pv.DataSet) -> None:
    """Merge Dirichlet mask/value point data from a PyVista dataset.

    Reads the point-data arrays ``GLOBAL_POINT_ID``, ``DIRICHLET_MASK`` and
    ``DIRICHLET_VALUE`` from ``obj``. The builder is grown so that the largest
    global point id fits. The mask/value rows at those ids are then overwritten
    with the dataset's values, after broadcasting them to
    ``(obj.n_points, self.dim)`` via ``_left_broadcast_to``.

    A dataset whose ``GLOBAL_POINT_ID`` array is empty is a no-op.
    """
    point_id = obj.point_data[GLOBAL_POINT_ID]
    if point_id.size == 0:
        # `.max()` on an empty array raises ValueError, and there is nothing to merge.
        return
    self.resize(point_id.max() + 1)
    dirichlet_mask: Bool[np.ndarray, "points dim"] = _left_broadcast_to(
        obj.point_data[DIRICHLET_MASK], (obj.n_points, self.dim)
    )
    dirichlet_value: Float[np.ndarray, "points dim"] = _left_broadcast_to(
        obj.point_data[DIRICHLET_VALUE], (obj.n_points, self.dim)
    )
    # Later datasets overwrite earlier ones at shared global ids.
    self.mask[point_id] = dirichlet_mask
    self.value[point_id] = dirichlet_value

finalize ¤

finalize() -> Dirichlet
Source code in src/liblaf/apple/jax/model/dirichlet/_builder.py
44
45
46
47
48
49
50
51
52
53
54
def finalize(self) -> Dirichlet:
    """Freeze the accumulated mask/value arrays into an immutable ``Dirichlet``.

    The flattened mask is split into two sets of positions:

    * ``dirichlet_index``: the positions where the mask is ``True``.
    * ``free_index``: the complementary positions.

    ``dirichlet_value`` holds the flattened ``value`` entries at ``dirichlet_index``.
    """
    fixed_mask: Bool[Array, "points dim"] = jnp.asarray(self.mask)
    fixed_idx: Integer[Array, " dirichlet"] = jnp.flatnonzero(fixed_mask)
    fixed_vals = jnp.asarray(self.value.ravel()[fixed_idx])
    free_idx = jnp.flatnonzero(~fixed_mask)
    return Dirichlet(
        dim=self.dim,
        n_points=self.n_points,
        fixed_mask=fixed_mask,
        dirichlet_index=fixed_idx,
        dirichlet_value=fixed_vals,
        free_index=free_idx,
    )

resize ¤

resize(n_points: int) -> None
Source code in src/liblaf/apple/jax/model/dirichlet/_builder.py
56
57
58
59
60
61
def resize(self, n_points: int) -> None:
    """Grow ``mask``/``value`` to at least ``n_points`` rows; never shrinks.

    New rows are padded at the end. In ``mask`` they are ``False`` (unconstrained)
    and in ``value`` they are ``0.0``. If ``n_points`` does not exceed
    ``self.n_points``, nothing changes.
    """
    extra: int = n_points - self.n_points
    if extra > 0:
        widths = ((0, extra), (0, 0))
        self.mask = np.pad(self.mask, widths, constant_values=False)
        self.value = np.pad(self.value, widths, constant_values=0.0)

JaxEnergy ¤

Parameters:

  • name ¤

    (str, default: <dynamic> ) –
  • requires_grad ¤

    (frozenset[str], default: frozenset() ) –

Methods:

  • fun
  • grad
  • grad_and_hess_diag
  • hess_diag
  • hess_prod
  • hess_quad
  • init_state
  • mixed_derivative_prod
  • update
  • update_materials
  • value_and_grad

Attributes:

  • name (str) –
  • requires_grad (frozenset[str]) –

name class-attribute instance-attribute ¤

name: str = static(default=name_factory, kw_only=True)

requires_grad class-attribute instance-attribute ¤

requires_grad: frozenset[str] = static(
    default=frozenset(), kw_only=True
)

fun ¤

fun(state: JaxEnergyState, u: Vector) -> Scalar
Source code in src/liblaf/apple/jax/model/_energy.py
34
35
def fun(self, state: JaxEnergyState, u: Vector) -> Scalar:
    """Scalar energy at ``u``; abstract, so subclasses must override it.

    The default ``grad``, ``hess_prod`` and ``value_and_grad`` implementations
    differentiate this method.

    Raises:
        NotImplementedError: always, in this base class.
    """
    raise NotImplementedError()

grad ¤

grad(state: JaxEnergyState, u: Vector) -> Updates
Source code in src/liblaf/apple/jax/model/_energy.py
37
38
39
40
@jarp.jit(inline=True)
def grad(self, state: JaxEnergyState, u: Vector) -> Updates:
    """Gradient of ``fun(state, .)`` at ``u`` as a dense update.

    Returns:
        ``(values, index)``, where ``index`` is ``arange(u.shape[0])``. The values
        therefore cover every entry of ``u``.
    """
    objective = jarp.partial(self.fun, state)
    gradient: Vector = eqx.filter_grad(objective)(u)
    every_index = jnp.arange(u.shape[0])
    return gradient, every_index

grad_and_hess_diag ¤

grad_and_hess_diag(
    state: JaxEnergyState, u: Vector
) -> tuple[Updates, Updates]
Source code in src/liblaf/apple/jax/model/_energy.py
67
68
69
70
71
@jarp.jit(inline=True)
def grad_and_hess_diag(
    self, state: JaxEnergyState, u: Vector
) -> tuple[Updates, Updates]:
    """Return ``(self.grad(state, u), self.hess_diag(state, u))``.

    Subclasses may override this to share work between the two computations.
    """
    gradient_updates = self.grad(state, u)
    diagonal_updates = self.hess_diag(state, u)
    return gradient_updates, diagonal_updates

hess_diag ¤

hess_diag(state: JaxEnergyState, u: Vector) -> Updates
Source code in src/liblaf/apple/jax/model/_energy.py
42
43
def hess_diag(self, state: JaxEnergyState, u: Vector) -> Updates:
    """Diagonal of the Hessian at ``u``; abstract, so subclasses must override it.

    Raises:
        NotImplementedError: always, in this base class.
    """
    raise NotImplementedError()

hess_prod ¤

hess_prod(
    state: JaxEnergyState, u: Vector, p: Vector
) -> Updates
Source code in src/liblaf/apple/jax/model/_energy.py
45
46
47
48
49
@jarp.jit(inline=True)
def hess_prod(self, state: JaxEnergyState, u: Vector, p: Vector) -> Updates:
    """Hessian-vector product ``H(u) @ p``.

    Computed as the forward-mode derivative (``jax.jvp``) of ``jax.grad(fun)``
    along ``p``.

    Returns:
        ``(values, arange(u.shape[0]))``, a dense update.
    """
    gradient_fn = jax.grad(jarp.partial(self.fun, state))
    _primal_out, tangent_out = jax.jvp(gradient_fn, (u,), (p,))
    return tangent_out, jnp.arange(u.shape[0])

hess_quad ¤

hess_quad(
    state: JaxEnergyState, u: Vector, p: Vector
) -> Scalar
Source code in src/liblaf/apple/jax/model/_energy.py
51
52
53
54
55
56
@jarp.jit(inline=True)
def hess_quad(self, state: JaxEnergyState, u: Vector, p: Vector) -> Scalar:
    """Quadratic form ``p^T H(u) p``.

    Builds on ``hess_prod``: ``p`` is restricted to the returned index set before
    taking the dot product.
    """
    product, where = self.hess_prod(state, u, p)
    return jnp.vdot(p[where], product)

init_state ¤

init_state(u: Vector) -> JaxEnergyState
Source code in src/liblaf/apple/jax/model/_energy.py
25
26
def init_state(self, u: Vector) -> JaxEnergyState:  # noqa: ARG002
    """Create the per-energy state; the base class ignores ``u`` and returns an empty state."""
    fresh = JaxEnergyState()
    return fresh

mixed_derivative_prod ¤

mixed_derivative_prod(
    state: JaxEnergyState, u: Vector, p: Vector
) -> dict[str, Array]
Source code in src/liblaf/apple/jax/model/_energy.py
73
74
75
76
77
78
79
80
@jarp.jit(inline=True)
def mixed_derivative_prod(
    self, state: JaxEnergyState, u: Vector, p: Vector
) -> dict[str, Array]:
    """Collect mixed-derivative products for every name in ``self.requires_grad``.

    For each ``name``, this calls ``self.mixed_derivative_prod_<name>(state, u, p)``.
    A missing method of that name raises ``AttributeError``.

    Returns:
        A dict keyed by the names in ``requires_grad``.
    """
    return {
        name: getattr(self, f"mixed_derivative_prod_{name}")(state, u, p)
        for name in self.requires_grad
    }

update ¤

update(state: JaxEnergyState, u: Vector) -> JaxEnergyState
Source code in src/liblaf/apple/jax/model/_energy.py
28
29
def update(self, state: JaxEnergyState, u: Vector) -> JaxEnergyState:  # noqa: ARG002
    """Refresh ``state`` for a new ``u``; the base class returns ``state`` unchanged."""
    unchanged = state
    return unchanged

update_materials ¤

update_materials(materials: EnergyMaterials) -> None
Source code in src/liblaf/apple/jax/model/_energy.py
31
32
def update_materials(self, materials: EnergyMaterials) -> None:
    """Apply new material parameters; the base class ignores them (no-op)."""
    return None

value_and_grad ¤

value_and_grad(
    state: JaxEnergyState, u: Vector
) -> tuple[Scalar, Updates]
Source code in src/liblaf/apple/jax/model/_energy.py
58
59
60
61
62
63
64
65
@jarp.jit(inline=True)
def value_and_grad(
    self, state: JaxEnergyState, u: Vector
) -> tuple[Scalar, Updates]:
    """Energy value and dense gradient update at ``u``.

    Returns:
        ``(value, (gradient, arange(u.shape[0])))``.
    """
    objective = jarp.partial(self.fun, state)
    energy_value, gradient = jax.value_and_grad(objective)(u)
    return energy_value, (gradient, jnp.arange(u.shape[0]))

JaxEnergyState ¤

JaxModel ¤

Parameters:

  • energies ¤

    (dict[str, JaxEnergy], default: <class 'dict'> ) –

    Energy terms of the model, keyed by name; defaults to an empty dict.

Methods:

  • fun
  • grad
  • grad_and_hess_diag
  • hess_diag
  • hess_prod
  • hess_quad
  • init_state
  • mixed_derivative_prod
  • update
  • update_materials
  • value_and_grad

Attributes:

  • energies (dict[str, JaxEnergy]) –

energies class-attribute instance-attribute ¤

energies: dict[str, JaxEnergy] = field(factory=dict)

fun ¤

fun(state: JaxModelState, u: Vector) -> Scalar
Source code in src/liblaf/apple/jax/model/_model.py
38
39
40
41
42
43
@jarp.jit(inline=True)
def fun(self, state: JaxModelState, u: Vector) -> Scalar:
    """Total energy: the sum of each term's ``fun``, starting from a zero scalar.

    Each term receives its own state, looked up by ``energy.name``.
    """
    total: Scalar = jnp.zeros(())
    for term in self.energies.values():
        total = total + term.fun(state.data[term.name], u)
    return total

grad ¤

grad(state: JaxModelState, u: Vector) -> Vector
Source code in src/liblaf/apple/jax/model/_model.py
45
46
47
48
49
50
51
52
53
@jarp.jit(inline=True)
def grad(self, state: JaxModelState, u: Vector) -> Vector:
    """Total gradient, shaped like ``u``.

    Each term returns ``(values, index)``. The values are scatter-added
    (``.at[index].add``) into a zero array, so repeated indices accumulate.
    """
    accumulated: Vector = jnp.zeros_like(u)
    for term in self.energies.values():
        term_values, term_index = term.grad(state[term.name], u)
        accumulated = accumulated.at[term_index].add(term_values)
    return accumulated

grad_and_hess_diag ¤

grad_and_hess_diag(
    state: JaxModelState, u: Vector
) -> tuple[Vector, Vector]
Source code in src/liblaf/apple/jax/model/_model.py
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
@jarp.jit(inline=True)
def grad_and_hess_diag(
    self, state: JaxModelState, u: Vector
) -> tuple[Vector, Vector]:
    """Total gradient and total Hessian diagonal, both shaped like ``u``.

    Each term supplies two ``(values, index)`` updates. They are scatter-added
    into separate zero arrays.
    """
    total_grad: Vector = jnp.zeros_like(u)
    total_diag: Vector = jnp.zeros_like(u)
    for term in self.energies.values():
        grad_update, diag_update = term.grad_and_hess_diag(state[term.name], u)
        g_values, g_index = grad_update
        d_values, d_index = diag_update
        total_grad = total_grad.at[g_index].add(g_values)
        total_diag = total_diag.at[d_index].add(d_values)
    return total_grad, total_diag

hess_diag ¤

hess_diag(state: JaxModelState, u: Vector) -> Vector
Source code in src/liblaf/apple/jax/model/_model.py
55
56
57
58
59
60
61
62
63
@jarp.jit(inline=True)
def hess_diag(self, state: JaxModelState, u: Vector) -> Vector:
    """Total Hessian diagonal, shaped like ``u``.

    Each term's ``(values, index)`` update is scatter-added into a zero array.
    """
    accumulated: Vector = jnp.zeros_like(u)
    for term in self.energies.values():
        term_values, term_index = term.hess_diag(state[term.name], u)
        accumulated = accumulated.at[term_index].add(term_values)
    return accumulated

hess_prod ¤

hess_prod(
    state: JaxModelState, u: Vector, p: Vector
) -> Vector
Source code in src/liblaf/apple/jax/model/_model.py
65
66
67
68
69
70
71
72
73
@jarp.jit(inline=True)
def hess_prod(self, state: JaxModelState, u: Vector, p: Vector) -> Vector:
    """Total Hessian-vector product ``H(u) @ p``, shaped like ``u``.

    Each term's ``(values, index)`` update is scatter-added into a zero array.
    """
    accumulated: Vector = jnp.zeros_like(u)
    for term in self.energies.values():
        term_values, term_index = term.hess_prod(state[term.name], u, p)
        accumulated = accumulated.at[term_index].add(term_values)
    return accumulated

hess_quad ¤

hess_quad(
    state: JaxModelState, u: Vector, p: Vector
) -> Scalar
Source code in src/liblaf/apple/jax/model/_model.py
75
76
77
78
79
80
@jarp.jit(inline=True)
def hess_quad(self, state: JaxModelState, u: Vector, p: Vector) -> Scalar:
    """Total quadratic form ``p^T H(u) p``: the sum of each term's ``hess_quad``."""
    total: Scalar = jnp.zeros(())
    for term in self.energies.values():
        total = total + term.hess_quad(state[term.name], u, p)
    return total

init_state ¤

init_state(u: Vector) -> JaxModelState
Source code in src/liblaf/apple/jax/model/_model.py
22
23
24
25
26
def init_state(self, u: Vector) -> JaxModelState:
    """Build the model state at ``u``.

    Each energy's ``init_state(u)`` result is stored under ``energy.name``. If two
    energies share a name, the later one wins.
    """
    per_energy: dict[str, JaxEnergyState] = {
        term.name: term.init_state(u) for term in self.energies.values()
    }
    return JaxModelState(u=u, data=per_energy)

mixed_derivative_prod ¤

mixed_derivative_prod(
    state: JaxModelState, u: Vector, p: Vector
) -> dict[str, dict[str, Array]]
Source code in src/liblaf/apple/jax/model/_model.py
82
83
84
85
86
87
88
89
@jarp.jit(inline=True)
def mixed_derivative_prod(
    self, state: JaxModelState, u: Vector, p: Vector
) -> dict[str, dict[str, Array]]:
    """Gather each energy's ``mixed_derivative_prod`` output.

    Outputs are keyed by the energy's key in ``self.energies``. The state passed to
    each energy is looked up by ``energy.name``.
    """
    collected: dict[str, dict[str, Array]] = {}
    for key, term in self.energies.items():
        collected[key] = term.mixed_derivative_prod(state[term.name], u, p)
    return collected

update ¤

update(state: JaxModelState, u: Vector) -> JaxModelState
Source code in src/liblaf/apple/jax/model/_model.py
28
29
30
31
32
def update(self, state: JaxModelState, u: Vector) -> JaxModelState:
    """Advance every energy's state to ``u`` and record ``u`` on ``state``.

    ``JaxEnergy.update`` returns the (possibly new) per-energy state. That return
    value is stored back under ``energy.name``, so energies that update
    functionally (by returning a new object instead of mutating) are not lost.

    Returns:
        ``state``, mutated in place.
    """
    for energy in self.energies.values():
        state.data[energy.name] = energy.update(state.data[energy.name], u)
    state.u = u
    return state

update_materials ¤

update_materials(materials: ModelMaterials) -> None
Source code in src/liblaf/apple/jax/model/_model.py
34
35
36
def update_materials(self, materials: ModelMaterials) -> None:
    """Forward each energy's materials to that energy.

    ``materials`` is keyed by the energy's key in ``self.energies``. An unknown key
    raises ``KeyError``.
    """
    for key, params in materials.items():
        target = self.energies[key]
        target.update_materials(params)

value_and_grad ¤

value_and_grad(
    state: JaxModelState, u: Vector
) -> tuple[Scalar, Vector]
Source code in src/liblaf/apple/jax/model/_model.py
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
@jarp.jit(inline=True)
def value_and_grad(self, state: JaxModelState, u: Vector) -> tuple[Scalar, Vector]:
    """Total energy and total gradient (shaped like ``u``) in a single pass.

    Each term returns ``(value, (grad_values, index))``. The values are summed and
    the gradients are scatter-added.
    """
    total: Scalar = jnp.zeros(())
    gradient: Vector = jnp.zeros_like(u)
    for term in self.energies.values():
        term_value, (term_grad, term_index) = term.value_and_grad(
            state[term.name], u
        )
        total = total + term_value
        gradient = gradient.at[term_index].add(term_grad)
    return total, gradient

JaxModelBuilder ¤

Parameters:

  • energies ¤

    (dict[str, JaxEnergy], default: <class 'dict'> ) –

    Energy terms registered so far, keyed by `energy.name` (see `add_energy`); defaults to an empty dict.

Methods:

  • add_energy
  • finalize

Attributes:

  • energies (dict[str, JaxEnergy]) –

energies class-attribute instance-attribute ¤

energies: dict[str, JaxEnergy] = field(
    factory=dict, kw_only=True
)

add_energy ¤

add_energy(energy: JaxEnergy) -> None
Source code in src/liblaf/apple/jax/model/_builder.py
11
12
def add_energy(self, energy: JaxEnergy) -> None:
    """Register ``energy`` under ``energy.name``, replacing any energy with the same name."""
    key = energy.name
    self.energies[key] = energy

finalize ¤

finalize() -> JaxModel
Source code in src/liblaf/apple/jax/model/_builder.py
14
15
def finalize(self) -> JaxModel:
    """Build a ``JaxModel`` from the energies registered so far.

    The name-to-energy mapping is shallow-copied. Calling ``add_energy`` on this
    builder afterwards therefore does not silently change an already-finalized
    model. The energy objects themselves are shared, not copied.
    """
    return JaxModel(energies=dict(self.energies))

JaxModelState ¤

Bases: MutableMapping[str, JaxEnergyState]


              flowchart TD
              liblaf.apple.jax.model.JaxModelState[JaxModelState]

              

              click liblaf.apple.jax.model.JaxModelState href "" "liblaf.apple.jax.model.JaxModelState"
            

Parameters:

  • u ¤

    (Vector) –
  • data ¤

    (dict[str, JaxEnergyState], default: <class 'dict'> ) –

    Per-energy states, keyed by energy name (as filled in by `JaxModel.init_state`); defaults to an empty dict.

Methods:

  • __delitem__
  • __getitem__
  • __iter__
  • __len__
  • __setitem__

Attributes:

  • data (dict[str, JaxEnergyState]) –
  • u (Vector) –

data class-attribute instance-attribute ¤

data: dict[str, JaxEnergyState] = field(factory=dict)

u instance-attribute ¤

u: Vector

__delitem__ ¤

__delitem__(key: str) -> None
Source code in src/liblaf/apple/jax/model/_state.py
25
26
def __delitem__(self, key: str) -> None:
    """Remove the per-energy state stored under ``key``; raises ``KeyError`` if absent."""
    store = self.data
    del store[key]

__getitem__ ¤

__getitem__(key: str) -> JaxEnergyState
Source code in src/liblaf/apple/jax/model/_state.py
19
20
def __getitem__(self, key: str) -> JaxEnergyState:
    """Return the per-energy state stored under ``key``; raises ``KeyError`` if absent."""
    store = self.data
    return store[key]

__iter__ ¤

__iter__() -> Iterator[str]
Source code in src/liblaf/apple/jax/model/_state.py
28
29
def __iter__(self) -> Iterator[str]:
    """Iterate over the energy names (keys of ``data``) in insertion order."""
    return iter(self.data.keys())

__len__ ¤

__len__() -> int
Source code in src/liblaf/apple/jax/model/_state.py
31
32
def __len__(self) -> int:
    """Return the number of per-energy states held in ``data``."""
    count = len(self.data)
    return count

__setitem__ ¤

__setitem__(key: str, value: JaxEnergyState) -> None
Source code in src/liblaf/apple/jax/model/_state.py
22
23
def __setitem__(self, key: str, value: JaxEnergyState) -> None:
    """Store ``value`` as the per-energy state under ``key``, overwriting any previous entry."""
    store = self.data
    store[key] = value