
array_typing.python

Functions:

as_scalar

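as_scalar(x: ScalarLike) -> Scalar

Convert a scalar-like value to a built-in Python scalar. JAX, NumPy and PyTorch arrays are unwrapped with .item(), which requires the array to hold exactly one element. Any other value is returned unchanged.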
Source code in src/array_typing/python/_cast.py
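def as_scalar(x: at.ScalarLike) -> at.Scalar:
    # JAX, NumPy and PyTorch arrays each provide .item(), which returns the
    # single element of a one-element array as a built-in Python scalar.
    if at.is_jax(x):
        return x.item()
    if at.is_numpy(x):
        return x.item()
    if at.is_torch(x):
        return x.item()
    # Anything else is assumed to already be a Python scalar and is returned as-is.
    return x  # pyright: ignore [reportReturnType]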
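The conversion above relies on the fact that all three array libraries expose an .item() method. A minimal duck-typed sketch of the same idea follows, assuming only the standard library. It is not the library's implementation: the names as_scalar_sketch and PyScalar are hypothetical, the framework checks (at.is_jax, at.is_numpy, at.is_torch) are replaced by a check for a callable item attribute, and, unlike as_scalar, it raises TypeError for values it cannot convert instead of returning them unchanged.

```python
from typing import Any, Union

# Hypothetical stand-in for at.Scalar: the built-in Python scalar types.
PyScalar = Union[bool, int, float, complex]


def as_scalar_sketch(x: Any) -> PyScalar:
    """Duck-typed sketch of as_scalar (hypothetical, not array_typing's code)."""
    item = getattr(x, "item", None)
    if callable(item):
        # JAX, NumPy and PyTorch arrays (and NumPy scalar types such as
        # np.float64) expose .item(), which returns a built-in Python scalar
        # and raises if the array holds more than one element.
        return item()
    if isinstance(x, (int, float, complex)):
        # Already a Python scalar; bool is accepted as a subclass of int.
        return x
    # Deliberate difference from as_scalar: reject unknown values loudly.
    raise TypeError(f"cannot convert {type(x).__name__} to a Python scalar")


if __name__ == "__main__":
    print(as_scalar_sketch(3))  # plain Python scalars pass through
    try:
        import numpy as np
    except ImportError:
        pass  # NumPy is optional for this demo
    else:
        # A 0-d NumPy array is unwrapped to a built-in int via .item().
        print(type(as_scalar_sketch(np.array(3))))
```

The .item() check runs before the isinstance check because NumPy scalar types such as np.float64 subclass float; testing for .item() first ensures they are still converted to built-in floats. Checking for .item() instead of the framework lets the sketch accept any array library that follows the same convention, at the cost of also accepting unrelated objects that happen to define an item method. The explicit is_jax, is_numpy and is_torch checks in as_scalar avoid that ambiguity.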