Source code for opinf.pre._base
# pre/_base.py
"""Template class for transformers."""
# Public API of this module: only the template class is exported.
__all__ = ["TransformerTemplate"]
import abc
import numpy as np
import scipy.linalg as la
from .. import errors, ddt, utils
# Object built by ``utils.requires2`` from the attribute name
# "state_dimension" and the message to report when a transformer has not been
# trained. NOTE(review): presumably a decorator for methods that need a fitted
# transformer -- the exact semantics live in ``utils``; confirm there.
requires_trained = utils.requires2(
"state_dimension",
"transformer not trained, call fit() or fit_transform()",
)
[docs]
class TransformerTemplate(abc.ABC):
    """Template class for transformers.

    Classes that inherit from this template must implement the methods
    :meth:`fit_transform()`, :meth:`transform()`, and
    :meth:`inverse_transform()`. The optional :meth:`transform_ddts()` method
    is used by the ROM class when snapshot time derivative data are available
    in the native state variables.

    See :class:`ShiftScaleTransformer` for an example.

    The default implementation of :meth:`fit()` simply calls
    :meth:`fit_transform()`.

    Parameters
    ----------
    name : str
        Label for the state variable that this transformer acts on.
    """

    def __init__(self, name: str = None):
        """Initialize attributes."""
        self.__n = None
        self.__name = name

    # Properties --------------------------------------------------------------
    @property
    def state_dimension(self):
        r"""Dimension :math:`n` of the state."""
        return self.__n

    @state_dimension.setter
    def state_dimension(self, n):
        """Set the state dimension (cast to ``int`` unless ``None``)."""
        self.__n = int(n) if n is not None else None

    @property
    def name(self):
        """Label for the state variable that this transformer acts on."""
        return self.__name

    @name.setter
    def name(self, label):
        """Set the state variable name (cast to ``str`` unless ``None``)."""
        self.__name = str(label) if label is not None else None

    def __str__(self) -> str:
        """String representation: class name, variable name (if set), and
        state dimension (if set).
        """
        out = [self.__class__.__name__]
        if self.name is not None:
            out[0] += f" for variable '{self.name}'"
        if self.state_dimension is not None:
            out.append(f"state_dimension: {self.state_dimension:d}")
        return "\n ".join(out)

    def __repr__(self) -> str:
        """Unique ID + string representation."""
        return utils.str2repr(self)

    def _check_shape(self, Q):
        """Verify the shape of the snapshot set Q.

        Parameters
        ----------
        Q : ndarray
            Snapshot data whose first axis is compared to
            :attr:`state_dimension`.

        Raises
        ------
        ValueError
            If :attr:`state_dimension` is set and ``Q.shape[0]`` differs
            from it. No check is made while the dimension is ``None``.
        """
        if (n := self.state_dimension) is not None and (n2 := Q.shape[0]) != n:
            raise ValueError(
                f"states.shape[0] = {n2:d} != {n:d} = state_dimension"
            )

    def _check_locs(self, locs, states_at_locs, label="states_transformed"):
        """Verify that the locs and states are aligned.

        Parameters
        ----------
        locs : slice or (p,) array-like of integers
            Indices at which ``states_at_locs`` is given. A slice is expanded
            against ``numpy.arange(state_dimension)``.
        states_at_locs : (p, ...) ndarray
            Data whose first axis should have one entry per index in ``locs``.
        label : str
            Name used for ``states_at_locs`` in the error message.

        Returns
        -------
        locs : (p,) ndarray
            The indices as an array (the input itself if it was already an
            ndarray).

        Raises
        ------
        ValueError
            If ``states_at_locs.shape[0]`` does not match the number of
            indices in ``locs``.
        """
        if isinstance(locs, slice):
            locs = np.arange(self.state_dimension)[locs]
        else:
            # Lists / tuples of indices have no ``.size``; normalize them.
            # ndarray inputs pass through unchanged (same object).
            locs = np.asarray(locs)
        if states_at_locs.shape[0] != locs.size:
            raise ValueError(f"{label} not aligned with locs")
        return locs

    # Main routines -----------------------------------------------------------
    def fit(self, states):
        """Learn (but do not apply) the transformation.

        Parameters
        ----------
        states : (n, k) ndarray
            Matrix of `k` `n`-dimensional snapshots.

        Returns
        -------
        self
        """
        # inplace=False so that the caller's data are left untouched.
        self.fit_transform(states, inplace=False)
        return self

    @abc.abstractmethod
    def fit_transform(self, states, inplace=False):
        """Learn and apply the transformation.

        Parameters
        ----------
        states : (n, k) ndarray
            Matrix of `k` `n`-dimensional snapshots.
        inplace : bool
            If ``True``, overwrite ``states`` during transformation.
            If ``False``, create a copy of the data to transform.

        Returns
        -------
        states_transformed: (n, k) ndarray
            Matrix of `k` `n`-dimensional transformed snapshots.
        """
        raise NotImplementedError  # pragma: no cover

    @abc.abstractmethod
    def transform(self, states, inplace=False):
        """Apply the learned transformation.

        Parameters
        ----------
        states : (n, ...) ndarray
            Matrix of `n`-dimensional snapshots, or a single snapshot.
        inplace : bool
            If ``True``, overwrite ``states`` during transformation.
            If ``False``, create a copy of the data to transform.

        Returns
        -------
        states_transformed: (n, ...) ndarray
            Matrix of `n`-dimensional transformed snapshots, or a single
            transformed snapshot.
        """
        raise NotImplementedError  # pragma: no cover

    def transform_ddts(self, ddts, inplace=False):
        r"""Apply the learned transformation to snapshot time derivatives.

        If the transformation is denoted by :math:`\mathcal{T}(q)`,
        this function implements :math:`\mathcal{T}'` such that
        :math:`\mathcal{T}'(\ddt q) = \ddt \mathcal{T}(q)`.

        The template implementation returns ``NotImplemented``, which
        :meth:`verify()` uses to detect that this optional method has not
        been overridden.

        Parameters
        ----------
        ddts : (n, ...) ndarray
            Matrix of `n`-dimensional snapshot time derivatives, or a
            single snapshot time derivative.
        inplace : bool
            If True, overwrite ``ddts`` during the transformation.
            If False, create a copy of the data to transform.

        Returns
        -------
        ddts_transformed : (n, ...) ndarray
            Transformed `n`-dimensional snapshot time derivatives.
        """
        return NotImplemented  # pragma: no cover

    @abc.abstractmethod
    def inverse_transform(self, states_transformed, inplace=False, locs=None):
        """Apply the inverse of the learned transformation.

        Parameters
        ----------
        states_transformed : (n, ...) or (p, ...) ndarray
            Matrix of `n`-dimensional transformed snapshots, or a single
            transformed snapshot.
        inplace : bool
            If ``True``, overwrite ``states_transformed`` during the inverse
            transformation. If ``False``, create a copy of the data to
            untransform.
        locs : slice or (p,) ndarray of integers or None
            If given, assume ``states_transformed`` contains the transformed
            snapshots at only the `p` indices described by ``locs``.

        Returns
        -------
        states_untransformed: (n, ...) or (p, ...) ndarray
            Matrix of `n`-dimensional untransformed snapshots, or the `p`
            entries of such at the indices specified by ``locs``.
        """
        raise NotImplementedError  # pragma: no cover

    # Model persistence -------------------------------------------------------
    def save(self, savefile, overwrite=False):
        """Save the transformer to an HDF5 file.

        The template does not implement saving: this base method always
        raises ``NotImplementedError``. Subclasses that support persistence
        should override it following the contract below.

        Parameters
        ----------
        savefile : str
            Path of the file to save the transformer to.
        overwrite : bool
            If ``True``, overwrite the file if it already exists.

        Raises
        ------
        NotImplementedError
            Always, unless the method is overridden by a subclass.
        FileExistsError
            (Contract for overriding implementations.)
            If ``overwrite=False`` but the ``savefile`` already exists.
        """
        raise NotImplementedError("use pickle/joblib")  # pragma: no cover

    @classmethod
    def load(cls, loadfile):
        """Load a previously saved transformer from an HDF5 file.

        The template does not implement loading: this base method always
        raises ``NotImplementedError``. Subclasses that support persistence
        should override it.

        Parameters
        ----------
        loadfile : str
            File where the transformer was stored via :meth:`save()`.

        Returns
        -------
        transformer

        Raises
        ------
        NotImplementedError
            Always, unless the method is overridden by a subclass.
        """
        raise NotImplementedError("use pickle/joblib")  # pragma: no cover

    # Verification ------------------------------------------------------------
    def verify(self, tol: float = 1e-4):
        r"""Verify that :meth:`transform()` and :meth:`inverse_transform()`
        are consistent and that :meth:`transform_ddts()`, if implemented,
        is consistent with :meth:`transform()`.

        * The :meth:`transform()` / :meth:`inverse_transform()` consistency
          check verifies that
          ``inverse_transform(transform(states)) == states``.
        * The :meth:`transform_ddts()` consistency check uses
          :meth:`opinf.ddt.ddt()` to estimate the time derivatives of the
          states and the transformed states, then verifies that the relative
          difference between
          ``transform_ddts(opinf.ddt.ddt(states, t))`` and
          ``opinf.ddt.ddt(transform(states), t)`` is less than ``tol``, where
          ``t = numpy.linspace(0, 0.1, 20)``.

        The checks run on random data with 20 snapshots and print a message
        for each stage that passes.

        Parameters
        ----------
        tol : float > 0
            Tolerance for the finite difference check of
            :meth:`transform_ddts()`.
            Only used if :meth:`transform_ddts()` is implemented.

        Raises
        ------
        AttributeError
            If :attr:`state_dimension` has not been set.
        errors.VerificationError
            If any of the consistency checks fails.
        """
        if (n := self.state_dimension) is None:
            raise AttributeError(
                "transformer not trained (state_dimension not set), "
                "call fit() or fit_transform()"
            )
        states = np.random.random((n, 20))

        # Verify transform().
        states_transformed = self.transform(states, inplace=False)
        if states_transformed.shape != states.shape:
            raise errors.VerificationError(
                "transform(states).shape != states.shape"
            )
        if states_transformed is states:
            raise errors.VerificationError(
                "transform(states, inplace=False) is states"
            )
        # Work on a copy so that `states` keeps the untransformed data.
        states_copy = states.copy()
        states_transformed = self.transform(states_copy, inplace=True)
        if states_transformed is not states_copy:
            raise errors.VerificationError(
                "transform(states, inplace=True) is not states"
            )

        # Verify inverse_transform().
        states_recovered = self.inverse_transform(
            states_transformed,
            inplace=False,
        )
        if states_recovered.shape != states.shape:
            raise errors.VerificationError(
                "inverse_transform(transform(states)).shape != states.shape"
            )
        if states_recovered is states_transformed:
            raise errors.VerificationError(
                "inverse_transform(states_transformed, inplace=False) "
                "is states_transformed"
            )
        # Copy again so `states_transformed` survives the in-place inverse.
        states_transformed_copy = states_transformed.copy()
        states_recovered = self.inverse_transform(
            states_transformed_copy,
            inplace=True,
        )
        if states_recovered is not states_transformed_copy:
            raise errors.VerificationError(
                "inverse_transform(states_transformed, inplace=True) "
                "is not states_transformed"
            )
        if not np.allclose(states_recovered, states):
            raise errors.VerificationError(
                "transform() and inverse_transform() are not inverses"
            )
        self._verify_locs(states, states_transformed)
        print("transform() and inverse_transform() are consistent")

        # Finite difference check for transform_ddts().
        if self.transform_ddts(states) is NotImplemented:
            # Optional method not overridden: nothing more to check.
            return
        t = np.linspace(0, 0.1, states.shape[1])
        ddts = ddt.ddt(states, t)
        ddts_transformed = self.transform_ddts(ddts, inplace=False)
        if ddts_transformed is ddts:
            raise errors.VerificationError(
                "transform_ddts(ddts, inplace=False) is ddts"
            )
        ddts_est = ddt.ddt(states_transformed, t)
        if (
            diff := la.norm(ddts_transformed - ddts_est) / la.norm(ddts_est)
        ) > tol:
            raise errors.VerificationError(
                "transform_ddts() failed finite difference check,\n\t"
                "|| transform_ddts(d/dt[states]) - d/dt[transform(states)] || "
                f" / || d/dt[transform(states)] || = {diff} > {tol} = tol"
            )
        ddts_transformed = self.transform_ddts(ddts, inplace=True)
        if ddts_transformed is not ddts:
            raise errors.VerificationError(
                "transform_ddts(ddts, inplace=True) is not ddts"
            )
        print("transform() and transform_ddts() are consistent")

    def _verify_locs(self, states, states_transformed):
        """Verification for inverse_transform() with locs != None.

        Draws a sorted random subset of row indices and checks that
        ``inverse_transform(states_transformed[locs], locs=locs)`` matches
        ``states[locs]`` in shape and (approximately) in value.

        Parameters
        ----------
        states : (n, k) ndarray
            Untransformed snapshots.
        states_transformed : (n, k) ndarray
            Transformed counterpart of ``states``.

        Raises
        ------
        errors.VerificationError
            If the recovered entries have the wrong shape or values.
        """
        n = states.shape[0]
        # Sample about a third of the indices, but never zero of them when
        # n > 0: with n < 3, ``n // 3 == 0`` would make this check vacuous.
        num_locs = min(n, max(1, n // 3))
        locs = np.sort(np.random.choice(n, size=num_locs, replace=False))
        states_transformed_at_locs = states_transformed[locs]
        states_recovered_at_locs = self.inverse_transform(
            states_transformed_at_locs,
            locs=locs,
        )
        states_at_locs = states[locs]
        if states_recovered_at_locs.shape != states_at_locs.shape:
            raise errors.VerificationError(
                "inverse_transform(transform(states)[locs], locs).shape "
                "!= states[locs].shape"
            )
        if not np.allclose(states_recovered_at_locs, states_at_locs):
            raise errors.VerificationError(
                "transform() and inverse_transform() are not inverses "
                "(locs != None)"
            )