Source code for opinf.lift._base

# lift/_base.py
"""Template class for lifting transformations."""

# Public names exported by this module.
__all__ = ["LifterTemplate"]

import abc
import numpy as np
import scipy.linalg as la

from .. import errors, ddt, utils


class LifterTemplate(abc.ABC):
    """Template class for lifting transformations.

    Classes that inherit from this template must implement the methods
    :meth:`lift()` and :meth:`unlift()`. The optional :meth:`lift_ddts`
    method is used by the ROM class when snapshot time derivative data are
    available in the native state variables.

    See :class:`QuadraticLifter` for an example.
    """

    def __str__(self):
        """String representation: class name."""
        return self.__class__.__name__

    def __repr__(self):
        """Unique ID + string representation."""
        return utils.str2repr(self)

    @staticmethod
    @abc.abstractmethod
    def lift(states):  # pragma: no cover
        """Lift the native state variables to the learning variables.

        Parameters
        ----------
        states : (n, k) ndarray
            Native state variables.

        Returns
        -------
        lifted_states : (n_new, k) ndarray
            Learning variables.
        """
        raise NotImplementedError

    @staticmethod
    def lift_ddts(states, ddts):  # pragma: no cover
        """Lift the native state time derivatives to the time derivatives
        of the learning variables.

        Parameters
        ----------
        states : (n, k) ndarray
            Native state variables.
        ddts : (n, k) ndarray
            Time derivatives of the native state variables. Each column
            ``ddts[:, j]`` corresponds to the state vector ``states[:, j]``.

        Returns
        -------
        lifted_ddts : (n_new, k) ndarray
            Time derivatives of the learning variables.

        Notes
        -----
        This default implementation returns ``NotImplemented``, which
        :meth:`verify` uses to detect that a subclass has not provided this
        optional method (and therefore skips the finite difference check).
        """
        return NotImplemented

    @staticmethod
    @abc.abstractmethod
    def unlift(lifted_states):  # pragma: no cover
        """Recover the native state variables from the learning variables.

        Parameters
        ----------
        lifted_states : (n_new, k) ndarray
            Learning variables.

        Returns
        -------
        states : (n, k) ndarray
            Native state variables.
        """
        raise NotImplementedError

    # Testing -----------------------------------------------------------------
    def verify(self, states, t=None, tol: float = 1e-4):
        r"""Verify that :meth:`lift` and :meth:`unlift` are consistent and
        that :meth:`lift_ddts`, if implemented, gives valid time derivatives.

        * The :meth:`lift` / :meth:`unlift` consistency check verifies that
          ``unlift(lift(states)) == states``.
        * The :meth:`lift_ddts` consistency check uses :meth:`opinf.ddt.ddt`
          to estimate the time derivatives of the states and the lifted
          states, then verifies that the relative difference between
          ``lift_ddts(states, opinf.ddt.ddt(states, t))`` and
          ``opinf.ddt.ddt(lift(states), t)`` is less than ``tol``.
          If the estimated lifted derivative is identically zero, the
          absolute difference is compared to ``tol`` instead (a relative
          difference is undefined in that case).
          If this check fails, consider using a finer time mesh.

        Parameters
        ----------
        states : (n, k) ndarray
            Native state variables.
        t : (k,) ndarray or None
            Time domain corresponding to the states.
            Only required if :meth:`lift_ddts` is implemented.
        tol : float > 0
            Tolerance for the finite difference check of :meth:`lift_ddts`.
            Only used if :meth:`lift_ddts` is implemented.

        Raises
        ------
        errors.VerificationError
            If any of the consistency checks fails.
        ValueError
            If :meth:`lift_ddts` is implemented but ``t`` is not provided.
        """
        # Verify lift() and unlift() are inverses.
        lifted_states = self.lift(states)
        if (k1 := lifted_states.shape[1]) != states.shape[1]:
            raise errors.VerificationError(
                f"{k1} = lift(states).shape[1] "
                f"!= states.shape[1] = {states.shape[1]}"
            )
        unlifted_states = self.unlift(lifted_states)
        if (shape := unlifted_states.shape) != states.shape:
            raise errors.VerificationError(
                f"{shape} = unlift(lift(states)).shape "
                f"!= states.shape = {states.shape}"
            )
        if not np.allclose(unlifted_states, states):
            raise errors.VerificationError("unlift(lift(states)) != states")
        print("lift() and unlift() are consistent")

        # Finite difference checks for lift_ddts().
        # Probe with placeholder derivative data: the template's default
        # lift_ddts() returns NotImplemented, meaning there is nothing to check.
        if self.lift_ddts(states, states) is NotImplemented:
            return
        if t is None:
            raise ValueError(
                "time domain 't' required for finite difference check"
            )
        lifted_ddts = self.lift_ddts(states, ddt.ddt(states, t))
        if (shape := lifted_ddts.shape) != (shape2 := lifted_states.shape):
            raise errors.VerificationError(
                f"{shape} = lift_ddts(states, ddts).shape "
                f"!= lift(states).shape = {shape2}"
            )
        lddts_est = ddt.ddt(lifted_states, t)
        error = la.norm(lifted_ddts - lddts_est)
        scale = la.norm(lddts_est)
        if scale > 0:
            # Relative error against the finite-difference estimate.
            if (diff := error / scale) > tol:
                raise errors.VerificationError(
                    "lift_ddts() failed finite difference check,\n\t"
                    "|| lift_ddts(states, d/dt[states]) - d/dt[lift(states)] || "
                    f" / || d/dt[lift(states)] || = {diff} > tol = {tol}"
                )
        elif (diff := error) > tol:
            # The estimated lifted derivative is identically zero (e.g., the
            # lifted states are constant in time), so a relative error would
            # divide by zero; the absolute error is compared instead.
            raise errors.VerificationError(
                "lift_ddts() failed finite difference check,\n\t"
                "|| lift_ddts(states, d/dt[states]) - d/dt[lift(states)] || "
                f"= {diff} > tol = {tol} (absolute error, "
                "since || d/dt[lift(states)] || = 0)"
            )
        print("lift() and lift_ddts() are consistent")