Skip to content

Documentation for parametrix

parametrix.Param

Bases: Module, Generic[T]

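Base class for a parameter: an Equinox module that stores a single JAX array (`raw_value`), exposes it through the `value` property, and forwards indexing, iteration, arithmetic operators, and numeric conversions to that value.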

Source code in parametrix/__init__.py
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
class Param(eqx.Module, Generic[T]):
    """Base class for a parameter.

    A `Param` stores a single JAX array in `raw_value` and exposes it through
    the `value` property. Indexing, iteration, arithmetic operators and the
    numeric conversions (`int`, `float`, `complex`, `operator.index`,
    `round`) are all forwarded to `value`, so a `Param` can be used largely
    like the array it wraps. Results are whatever the underlying array
    operation returns; they are not re-wrapped in a `Param`.
    """

    raw_value: jax.Array
    """The raw, stored value of the parameter."""

    def __init__(self, value: T | jax.Array | np.ndarray[Any, Any]) -> None:
        """Initialize the parameter.

        **Arguments:**

        - `value`: The value of the parameter. It is converted with
          `jnp.asarray` before being stored in `raw_value`.
        """
        self.raw_value = jnp.asarray(value)

    @property
    def value(self) -> jax.Array:
        """The value of the parameter.

        Subclasses can override this property to return any other value
        computed from `raw_value`.
        """
        return self.raw_value

    @staticmethod
    def _unwrap(other: Any) -> Any:
        """Return `other.value` if `other` is a `Param`, otherwise `other`."""
        # A parameter operand is reduced to its (possibly derived) `value`
        # so the array operation never sees a `Param` instance.
        return other.value if isinstance(other, Param) else other

    def __jax_array__(self) -> jax.Array:
        return self.value

    # --- container protocol -------------------------------------------------

    def __getitem__(self, key: Any) -> Any:
        return self.value[key]

    def __len__(self) -> int:
        return len(self.value)

    def __iter__(self) -> Iterator[Any]:
        return iter(self.value)

    def __contains__(self, item: object) -> bool:
        return item in self.value

    # --- binary operators ---------------------------------------------------
    # Each calls the array's dunder directly (not the operator symbol), so a
    # `NotImplemented` from the array propagates to Python's normal
    # reflected-operand fallback exactly as before.

    def __add__(self, other: Any) -> jax.Array:
        return self.value.__add__(self._unwrap(other))

    def __sub__(self, other: Any) -> jax.Array:
        return self.value.__sub__(self._unwrap(other))

    def __mul__(self, other: Any) -> jax.Array:
        return self.value.__mul__(self._unwrap(other))

    def __matmul__(self, other: Any) -> jax.Array:
        return self.value.__matmul__(self._unwrap(other))

    def __truediv__(self, other: Any) -> jax.Array:
        return self.value.__truediv__(self._unwrap(other))

    def __floordiv__(self, other: Any) -> jax.Array:
        return self.value.__floordiv__(self._unwrap(other))

    def __mod__(self, other: Any) -> jax.Array:
        return self.value.__mod__(self._unwrap(other))

    def __divmod__(self, other: Any) -> tuple[jax.Array, jax.Array]:
        return self.value.__divmod__(self._unwrap(other))

    def __pow__(self, other: Any) -> jax.Array:
        return self.value.__pow__(self._unwrap(other))

    # --- reflected binary operators -----------------------------------------

    def __radd__(self, other: Any) -> jax.Array:
        return self.value.__radd__(self._unwrap(other))

    def __rsub__(self, other: Any) -> jax.Array:
        return self.value.__rsub__(self._unwrap(other))

    def __rmul__(self, other: Any) -> jax.Array:
        return self.value.__rmul__(self._unwrap(other))

    def __rmatmul__(self, other: Any) -> jax.Array:
        return self.value.__rmatmul__(self._unwrap(other))

    def __rtruediv__(self, other: Any) -> jax.Array:
        return self.value.__rtruediv__(self._unwrap(other))

    def __rfloordiv__(self, other: Any) -> jax.Array:
        return self.value.__rfloordiv__(self._unwrap(other))

    def __rmod__(self, other: Any) -> jax.Array:
        return self.value.__rmod__(self._unwrap(other))

    def __rdivmod__(self, other: Any) -> tuple[jax.Array, jax.Array]:
        return self.value.__rdivmod__(self._unwrap(other))  # type: ignore[return-value]

    def __rpow__(self, other: Any) -> jax.Array:
        return self.value.__rpow__(self._unwrap(other))

    # --- unary operators ----------------------------------------------------

    def __neg__(self) -> jax.Array:
        return self.value.__neg__()

    def __pos__(self) -> jax.Array:
        return self.value.__pos__()

    def __abs__(self) -> jax.Array:
        return self.value.__abs__()

    def __invert__(self) -> jax.Array:
        return self.value.__invert__()

    # --- conversions --------------------------------------------------------

    def __complex__(self) -> complex:
        return self.value.__complex__()

    def __int__(self) -> int:
        return self.value.__int__()

    def __float__(self) -> float:
        return self.value.__float__()

    def __index__(self) -> int:
        return self.value.__index__()

    def __round__(self, ndigits: int | None = None) -> jax.Array:
        # The builtin `round(param)` calls `__round__` with NO argument, so
        # `ndigits` must be optional; `round(param, n)` is forwarded unchanged.
        if ndigits is None:
            return round(self.value)
        return round(self.value, ndigits)

raw_value: jax.Array = jnp.asarray(value) instance-attribute

The raw, stored value of the parameter.

value: jax.Array property

The value of the parameter.

Subclasses can override this property to return a different value computed from `raw_value`. All of `Param`'s operators and conversions act on `value`, not on `raw_value` directly.

__abs__() -> jax.Array

Source code in parametrix/__init__.py
149
150
def __abs__(self) -> jax.Array:
    """Return the element-wise absolute value of the parameter's value."""
    return abs(self.value)

__add__(other: Any) -> jax.Array

Source code in parametrix/__init__.py
53
54
55
56
def __add__(self, other: Any) -> jax.Array:
    """Add `other` to the parameter's value; a `Param` operand is unwrapped first."""
    rhs = other.value if isinstance(other, Param) else other
    return self.value.__add__(rhs)

__complex__() -> complex

Source code in parametrix/__init__.py
155
156
def __complex__(self) -> complex:
    """Convert the parameter's value to a Python `complex`."""
    return complex(self.value)

__contains__(item: object) -> bool

Source code in parametrix/__init__.py
50
51
def __contains__(self, item: object) -> bool:
    """Report whether `item` occurs in the parameter's value."""
    haystack = self.value
    return item in haystack

__divmod__(other: Any) -> tuple[jax.Array, jax.Array]

Source code in parametrix/__init__.py
88
89
90
91
def __divmod__(self, other: Any) -> tuple[jax.Array, jax.Array]:
    """Return `divmod(self.value, other)`; a `Param` operand is unwrapped first."""
    rhs = other.value if isinstance(other, Param) else other
    return self.value.__divmod__(rhs)

__float__() -> float

Source code in parametrix/__init__.py
161
162
def __float__(self) -> float:
    """Convert the parameter's value to a Python `float`."""
    return float(self.value)

__floordiv__(other: Any) -> jax.Array

Source code in parametrix/__init__.py
78
79
80
81
def __floordiv__(self, other: Any) -> jax.Array:
    """Floor-divide the parameter's value by `other`; a `Param` operand is unwrapped first."""
    rhs = other.value if isinstance(other, Param) else other
    return self.value.__floordiv__(rhs)

__getitem__(key: Any) -> Any

Source code in parametrix/__init__.py
41
42
def __getitem__(self, key: Any) -> Any:
    """Index into the parameter's value with `key`."""
    container = self.value
    return container.__getitem__(key)

__index__() -> int

Source code in parametrix/__init__.py
164
165
def __index__(self) -> int:
    """Return the parameter's value as a lossless Python integer index."""
    wrapped = self.value
    return wrapped.__index__()

__init__(value: T | jax.Array | np.ndarray[Any, Any]) -> None

Initialize the parameter.

Arguments:

  • value: The value of the parameter.
Source code in parametrix/__init__.py
20
21
22
23
24
25
26
27
def __init__(self, value: T | jax.Array | np.ndarray[Any, Any]) -> None:
    """Initialize the parameter.

    **Arguments:**

    - `value`: The value of the parameter; it is converted with
      `jnp.asarray` and stored as `raw_value`.
    """
    converted = jnp.asarray(value)
    self.raw_value = converted

__int__() -> int

Source code in parametrix/__init__.py
158
159
def __int__(self) -> int:
    """Convert the parameter's value to a Python `int`."""
    return int(self.value)

__invert__() -> jax.Array

Source code in parametrix/__init__.py
152
153
def __invert__(self) -> jax.Array:
    """Return the bitwise inversion (`~`) of the parameter's value."""
    return ~self.value

__iter__() -> Iterator[Any]

Source code in parametrix/__init__.py
47
48
def __iter__(self) -> Iterator[Any]:
    """Iterate over the parameter's value along its leading axis."""
    elements = self.value
    return iter(elements)

__jax_array__() -> jax.Array

Source code in parametrix/__init__.py
38
39
def __jax_array__(self) -> jax.Array:
    """Expose the parameter's value so JAX can treat the parameter as an array."""
    as_array = self.value
    return as_array

__len__() -> int

Source code in parametrix/__init__.py
44
45
def __len__(self) -> int:
    """Return the length of the parameter's value (size of its first axis)."""
    return self.value.__len__()

__matmul__(other: Any) -> jax.Array

Source code in parametrix/__init__.py
68
69
70
71
def __matmul__(self, other: Any) -> jax.Array:
    """Matrix-multiply the parameter's value by `other`; a `Param` operand is unwrapped first."""
    rhs = other.value if isinstance(other, Param) else other
    return self.value.__matmul__(rhs)

__mod__(other: Any) -> jax.Array

Source code in parametrix/__init__.py
83
84
85
86
def __mod__(self, other: Any) -> jax.Array:
    """Return the parameter's value modulo `other`; a `Param` operand is unwrapped first."""
    rhs = other.value if isinstance(other, Param) else other
    return self.value.__mod__(rhs)

__mul__(other: Any) -> jax.Array

Source code in parametrix/__init__.py
63
64
65
66
def __mul__(self, other: Any) -> jax.Array:
    """Multiply the parameter's value by `other`; a `Param` operand is unwrapped first."""
    rhs = other.value if isinstance(other, Param) else other
    return self.value.__mul__(rhs)

__neg__() -> jax.Array

Source code in parametrix/__init__.py
143
144
def __neg__(self) -> jax.Array:
    """Return the negation of the parameter's value."""
    return -self.value

__pos__() -> jax.Array

Source code in parametrix/__init__.py
146
147
def __pos__(self) -> jax.Array:
    """Return unary plus (`+`) applied to the parameter's value."""
    return +self.value

__pow__(other: Any) -> jax.Array

Source code in parametrix/__init__.py
93
94
95
96
def __pow__(self, other: Any) -> jax.Array:
    """Raise the parameter's value to the power `other`; a `Param` operand is unwrapped first."""
    exponent = other.value if isinstance(other, Param) else other
    return self.value.__pow__(exponent)

__radd__(other: Any) -> jax.Array

Source code in parametrix/__init__.py
 98
 99
100
101
def __radd__(self, other: Any) -> jax.Array:
    """Reflected addition `other + self.value`; a `Param` operand is unwrapped first."""
    lhs = other.value if isinstance(other, Param) else other
    return self.value.__radd__(lhs)

__rdivmod__(other: Any) -> tuple[jax.Array, jax.Array]

Source code in parametrix/__init__.py
133
134
135
136
def __rdivmod__(self, other: Any) -> tuple[jax.Array, jax.Array]:
    """Reflected `divmod(other, self.value)`; a `Param` operand is unwrapped first."""
    lhs = other.value if isinstance(other, Param) else other
    return self.value.__rdivmod__(lhs)  # type: ignore[return-value]

__rfloordiv__(other: Any) -> jax.Array

Source code in parametrix/__init__.py
123
124
125
126
def __rfloordiv__(self, other: Any) -> jax.Array:
    """Reflected floor division `other // self.value`; a `Param` operand is unwrapped first."""
    lhs = other.value if isinstance(other, Param) else other
    return self.value.__rfloordiv__(lhs)

__rmatmul__(other: Any) -> jax.Array

Source code in parametrix/__init__.py
113
114
115
116
def __rmatmul__(self, other: Any) -> jax.Array:
    """Reflected matrix product `other @ self.value`; a `Param` operand is unwrapped first."""
    lhs = other.value if isinstance(other, Param) else other
    return self.value.__rmatmul__(lhs)

__rmod__(other: Any) -> jax.Array

Source code in parametrix/__init__.py
128
129
130
131
def __rmod__(self, other: Any) -> jax.Array:
    """Reflected modulo `other % self.value`; a `Param` operand is unwrapped first."""
    lhs = other.value if isinstance(other, Param) else other
    return self.value.__rmod__(lhs)

__rmul__(other: Any) -> jax.Array

Source code in parametrix/__init__.py
108
109
110
111
def __rmul__(self, other: Any) -> jax.Array:
    """Reflected multiplication `other * self.value`; a `Param` operand is unwrapped first."""
    lhs = other.value if isinstance(other, Param) else other
    return self.value.__rmul__(lhs)

__round__(ndigits: int) -> jax.Array

Source code in parametrix/__init__.py
167
168
def __round__(self, ndigits: int) -> jax.Array:
    return self.value.__round__(ndigits)

__rpow__(other: Any) -> jax.Array

Source code in parametrix/__init__.py
138
139
140
141
def __rpow__(self, other: Any) -> jax.Array:
    """Reflected power `other ** self.value`; a `Param` operand is unwrapped first."""
    base = other.value if isinstance(other, Param) else other
    return self.value.__rpow__(base)

__rsub__(other: Any) -> jax.Array

Source code in parametrix/__init__.py
103
104
105
106
def __rsub__(self, other: Any) -> jax.Array:
    """Reflected subtraction `other - self.value`; a `Param` operand is unwrapped first."""
    lhs = other.value if isinstance(other, Param) else other
    return self.value.__rsub__(lhs)

__rtruediv__(other: Any) -> jax.Array

Source code in parametrix/__init__.py
118
119
120
121
def __rtruediv__(self, other: Any) -> jax.Array:
    """Reflected true division `other / self.value`; a `Param` operand is unwrapped first."""
    lhs = other.value if isinstance(other, Param) else other
    return self.value.__rtruediv__(lhs)

__sub__(other: Any) -> jax.Array

Source code in parametrix/__init__.py
58
59
60
61
def __sub__(self, other: Any) -> jax.Array:
    """Subtract `other` from the parameter's value; a `Param` operand is unwrapped first."""
    rhs = other.value if isinstance(other, Param) else other
    return self.value.__sub__(rhs)

__truediv__(other: Any) -> jax.Array

Source code in parametrix/__init__.py
73
74
75
76
def __truediv__(self, other: Any) -> jax.Array:
    """Divide the parameter's value by `other`; a `Param` operand is unwrapped first."""
    rhs = other.value if isinstance(other, Param) else other
    return self.value.__truediv__(rhs)