Skip to content

liblaf.apple.jax.sim.model ¤

Classes:

Model ¤

Parameters:

  • energies (Mapping[str, Energy], default: <class 'dict'> ) –

    Energy terms that make up the model, keyed by string identifier. The model's methods (`fun`, `jac`, `hess_diag`, `hess_prod`, `hess_quad`, ...) combine the contributions of all terms; with no terms, the scalar- and array-valued methods return zeros.

Methods:

Attributes:

energies class-attribute instance-attribute ¤

energies: Mapping[str, Energy] = field(factory=dict)

fun ¤

fun(u: Vector) -> Scalar
Source code in src/liblaf/apple/jax/sim/model/_model.py (lines 16-20)
def fun(self, u: Vector) -> Scalar:
    """Return the total energy of the model at ``u``.

    Each registered energy term is evaluated with ``term.fun(u)`` and the
    results are summed. A model without any terms evaluates to a zero
    scalar carrying ``u``'s dtype.
    """
    if self.energies:
        per_term = [term.fun(u) for term in self.energies.values()]
        return jnp.sum(jnp.asarray(per_term))
    # No terms: build an explicit zero so the result keeps u's dtype.
    return jnp.zeros((), u.dtype)

fun_and_jac ¤

fun_and_jac(u: Vector) -> tuple[Scalar, Vector]
Source code in src/liblaf/apple/jax/sim/model/_model.py (lines 84-101)
def fun_and_jac(self, u: Vector) -> tuple[Scalar, Vector]:
    """Return the total energy and its gradient with respect to ``u``.

    Each energy term's ``fun_and_jac(u)`` yields its scalar value together
    with a sparse gradient contribution ``(data, index)``. The values are
    summed. The contributions of all terms are concatenated and
    scatter-added with :func:`jax.ops.segment_sum` into ``u.shape[0]``
    segments, so entries that share an index accumulate.

    Returns:
        ``(energy, gradient)``. When the model has no energy terms, this is
        a zero scalar in ``u``'s dtype and ``zeros_like(u)``.
    """
    if not self.energies:
        return jnp.zeros((), u.dtype), jnp.zeros_like(u)
    value_list: list[Scalar] = []
    updates_data_list: list[UpdatesData] = []
    updates_index_list: list[UpdatesIndex] = []
    for energy in self.energies.values():
        # Distinct names here avoid re-declaring `fun` later in this scope
        # (and shadowing the method name), matching the sibling methods.
        value: Scalar
        data: UpdatesData
        index: UpdatesIndex
        value, (data, index) = energy.fun_and_jac(u)
        value_list.append(value)
        updates_data_list.append(data)
        updates_index_list.append(index)
    total: Scalar = jnp.sum(jnp.asarray(value_list))
    jac: Vector = jax.ops.segment_sum(
        jnp.concat(updates_data_list),
        jnp.concat(updates_index_list),
        num_segments=u.shape[0],
    )
    return total, jac

hess_diag ¤

hess_diag(u: Vector) -> Vector
Source code in src/liblaf/apple/jax/sim/model/_model.py (lines 40-56)
def hess_diag(self, u: Vector) -> Vector:
    """Return the diagonal of the model Hessian at ``u``.

    Every energy term's ``hess_diag(u)`` returns a sparse contribution as
    a ``(data, index)`` pair. All contributions are concatenated and
    scatter-added with :func:`jax.ops.segment_sum` into ``u.shape[0]``
    segments, so entries that share an index accumulate.

    Returns ``zeros_like(u)`` when the model has no energy terms.
    """
    if not self.energies:
        return jnp.zeros_like(u)
    contributions = [energy.hess_diag(u) for energy in self.energies.values()]
    all_data = jnp.concat([data for data, _ in contributions])
    all_index = jnp.concat([index for _, index in contributions])
    return jax.ops.segment_sum(all_data, all_index, num_segments=u.shape[0])

hess_prod ¤

hess_prod(u: Vector, p: Vector) -> Vector
Source code in src/liblaf/apple/jax/sim/model/_model.py (lines 58-74)
def hess_prod(self, u: Vector, p: Vector) -> Vector:
    """Return the model Hessian at ``u`` applied to the direction ``p``.

    Each energy term's ``hess_prod(u, p)`` returns a sparse
    ``(data, index)`` contribution. These are concatenated across terms
    and summed per index with :func:`jax.ops.segment_sum` into
    ``u.shape[0]`` segments.

    Returns ``zeros_like(u)`` when the model has no energy terms.
    """
    if not self.energies:
        return jnp.zeros_like(u)
    data_parts: list[UpdatesData] = []
    index_parts: list[UpdatesIndex] = []
    for term in self.energies.values():
        term_data, term_index = term.hess_prod(u, p)
        data_parts.append(term_data)
        index_parts.append(term_index)
    return jax.ops.segment_sum(
        jnp.concat(data_parts), jnp.concat(index_parts), num_segments=u.shape[0]
    )

hess_quad ¤

hess_quad(u: Vector, p: Vector) -> Scalar
Source code in src/liblaf/apple/jax/sim/model/_model.py (lines 76-82)
def hess_quad(self, u: Vector, p: Vector) -> Scalar:
    """Return the sum of the per-term Hessian quadratic forms.

    Each energy term's ``hess_quad(u, p)`` is expected to yield a scalar
    (presumably ``p^T H(u) p`` for that term -- confirm against
    ``Energy``). The model returns their sum. With no terms, it returns a
    zero scalar in ``u``'s dtype.
    """
    if self.energies:
        quads = [term.hess_quad(u, p) for term in self.energies.values()]
        return jnp.sum(jnp.asarray(quads))
    return jnp.zeros((), u.dtype)

jac ¤

jac(u: Vector) -> Vector
Source code in src/liblaf/apple/jax/sim/model/_model.py (lines 22-38)
def jac(self, u: Vector) -> Vector:
    """Return the gradient of the total energy with respect to ``u``.

    Each energy term's ``jac(u)`` produces a sparse ``(data, index)``
    contribution. All contributions are gathered, concatenated, and
    scatter-added with :func:`jax.ops.segment_sum` into ``u.shape[0]``
    segments.

    Returns ``zeros_like(u)`` when the model has no energy terms.
    """
    if not self.energies:
        return jnp.zeros_like(u)
    data_chunks: list[UpdatesData] = []
    index_chunks: list[UpdatesIndex] = []
    for data, index in (energy.jac(u) for energy in self.energies.values()):
        data_chunks.append(data)
        index_chunks.append(index)
    summed: Vector = jax.ops.segment_sum(
        jnp.concat(data_chunks),
        jnp.concat(index_chunks),
        num_segments=u.shape[0],
    )
    return summed

jac_and_hess_diag ¤

jac_and_hess_diag(u: Vector) -> tuple[Vector, Vector]
Source code in src/liblaf/apple/jax/sim/model/_model.py (lines 103-126)
def jac_and_hess_diag(self, u: Vector) -> tuple[Vector, Vector]:
    """Return the energy gradient and the Hessian diagonal at ``u`` together.

    Each energy term's ``jac_and_hess_diag(u)`` returns two sparse
    contributions: ``(jac_data, jac_index)`` and ``(diag_data, diag_index)``.
    The gradient parts and the diagonal parts are each concatenated across
    terms. Each set is then scatter-added with :func:`jax.ops.segment_sum`
    into ``u.shape[0]`` segments.

    Returns:
        ``(gradient, hessian_diagonal)``. Both are ``zeros_like(u)`` when
        the model has no energy terms.
    """
    if not self.energies:
        return jnp.zeros_like(u), jnp.zeros_like(u)
    grad_data: list[UpdatesData] = []
    grad_index: list[UpdatesIndex] = []
    diag_data: list[UpdatesData] = []
    diag_index: list[UpdatesIndex] = []
    for term in self.energies.values():
        grad_update, diag_update = term.jac_and_hess_diag(u)
        d_grad, i_grad = grad_update
        d_diag, i_diag = diag_update
        grad_data.append(d_grad)
        grad_index.append(i_grad)
        diag_data.append(d_diag)
        diag_index.append(i_diag)
    n_segments = u.shape[0]
    gradient: Vector = jax.ops.segment_sum(
        jnp.concat(grad_data), jnp.concat(grad_index), num_segments=n_segments
    )
    diagonal: Vector = jax.ops.segment_sum(
        jnp.concat(diag_data), jnp.concat(diag_index), num_segments=n_segments
    )
    return gradient, diagonal

mixed_derivative_prod ¤

mixed_derivative_prod(
    u: Vector, p: Vector
) -> dict[str, dict[str, Array]]
Source code in src/liblaf/apple/jax/sim/model/_model.py (lines 128-135)
def mixed_derivative_prod(
    self, u: Vector, p: Vector
) -> dict[str, dict[str, Array]]:
    """Collect every energy term's ``mixed_derivative_prod(u, p)`` result.

    The outer dict is keyed by ``energy.id``, not by the key under which the
    term is stored in ``self.energies``. If two terms share an ``id``, the
    later one overwrites the earlier. Each inner dict is whatever that term
    returns. A model with no terms yields an empty dict.
    """
    result: dict[str, dict[str, Array]] = {}
    for energy in self.energies.values():
        key = energy.id
        result[key] = energy.mixed_derivative_prod(u, p)
    return result

ModelBuilder ¤

Parameters:

  • dirichlet (DirichletBuilder, default: <dynamic> ) –
  • energies (dict[str, Energy], default: <class 'dict'> ) –

    Energy terms registered through `add_energy`, keyed by each energy's `id`. `finish` hands them to the resulting `Model`.

  • points (Float[Array, 'p J'], default: Array([], shape=(0, 3), dtype=float32) ) –

Methods:

Attributes:

dirichlet class-attribute instance-attribute ¤

dirichlet: DirichletBuilder = field(
    factory=DirichletBuilder
)

energies class-attribute instance-attribute ¤

energies: dict[str, Energy] = field(factory=dict)

n_points property ¤

n_points: int

points class-attribute instance-attribute ¤

points: Float[Array, "p J"] = array(factory=_default_points)

add_dirichlet ¤

add_dirichlet(mesh: DataSet) -> None
Source code in src/liblaf/apple/jax/sim/model/_model_builder.py (lines 27-28)
def add_dirichlet(self, mesh: pv.DataSet) -> None:
    """Forward ``mesh`` to this builder's ``dirichlet`` via its ``add`` method."""
    dirichlet = self.dirichlet
    dirichlet.add(mesh)

add_energy ¤

add_energy(energy: Energy) -> None
Source code in src/liblaf/apple/jax/sim/model/_model_builder.py (lines 30-31)
def add_energy(self, energy: Energy) -> None:
    """Register ``energy`` under its ``id``, replacing any term with the same id."""
    key = energy.id
    self.energies[key] = energy

assign_dofs ¤

assign_dofs(mesh: T) -> T
Source code in src/liblaf/apple/jax/sim/model/_model_builder.py (lines 33-39)
def assign_dofs[T: pv.DataSet](self, mesh: T) -> T:
    mesh.point_data["point-ids"] = np.arange(
        self.n_points, self.n_points + mesh.n_points
    )
    self.points = jnp.concat([self.points, mesh.points])
    self.dirichlet.resize(self.n_points)
    return mesh

finish ¤

finish() -> Model
Source code in src/liblaf/apple/jax/sim/model/_model_builder.py (lines 41-42)
def finish(self) -> Model:
    """Build a ``Model`` from the energy terms registered so far.

    The ``energies`` mapping is shallow-copied. Without the copy, the model
    would share the builder's dict, so a later ``add_energy`` call on this
    builder would silently change an already-built model. Only
    ``energies`` is passed to ``Model`` here; ``points`` and ``dirichlet``
    stay on the builder.
    """
    return Model(energies=dict(self.energies))