# Source code for opinf.basis._linear

# basis/_linear.py
"""Linear basis class: approximate states with a fixed, user-supplied basis."""

__all__ = ["LinearBasis"]

import warnings
import numpy as np
import scipy.sparse as sparse
import matplotlib.pyplot as plt

from .. import errors, utils
from ._base import BasisTemplate


# Decorator built by utils.requires2 from the attribute name "entries" and an
# error message; it is applied to the LinearBasis methods that read entries.
requires_entries = utils.requires2("entries", "basis entries not initialized")


class LinearBasis(BasisTemplate):
    r"""Linear low-dimensional state approximation.

    This class approximates high-dimensional states :math:`\q\in\RR^n` as a
    linear combination of :math:`r` basis vectors
    :math:`\v_1,\ldots,\v_r\in\RR^n`.
    The basis matrix :math:`\Vr = [~\v_1~~\cdots~~\v_r~]\in \RR^{n \times r}`
    and an (optional) weighting matrix :math:`\W\in\RR^{n \times n}` define
    the approximation.

    The encoding from the high-dimensional space :math:`\RR^n` to the
    low-dimensional space :math:`\RR^r` is given by

    .. math::
       \q \mapsto \qhat = \Vr\trp\W\q,

    while the decoding from the low-dimensional space :math:`\RR^r` to the
    high-dimensional space :math:`\RR^n` is defined as

    .. math::
       \qhat \mapsto \breve{\q} = \Vr\qhat = \sum_{i=1}^r \hat{q}_i \v_i,

    where :math:`\qhat = [\hat{q}_1,\ldots,\hat{q}_r]\trp\in\RR^r`.

    Basis entries :math:`\Vr` and the weights :math:`\W` are specified
    explicitly in the constructor, not learned from state data.

    Parameters
    ----------
    entries : (n, r) ndarray
        Basis entries :math:`\Vr\in\RR^{n\times r}`.
    weights : (n, n) ndarray, (n, n) sparse array, (n,) ndarray, or None
        Weight matrix :math:`\W\in\RR^{n \times n}` or its diagonal entries.
        One-dimensional weights are stored as a sparse diagonal array.
        If ``None`` (default), set :math:`\W` to the identity.
    check_orthogonality : bool
        If ``True`` (default), raise a warning if the basis is not
        orthogonal, i.e., if :math:`\Vr\trp\W\Vr` is not the identity.
    name : str or None
        Label for the state variable that this basis approximates.

    Raises
    ------
    ValueError
        If ``weights`` is not one- or two-dimensional, or if its size/shape
        does not match the full state dimension :math:`n`.

    Notes
    -----
    Pair with a :class:`opinf.pre.ShiftScaleTransformer` to do centered
    approximations of the form :math:`\q \approx\Vr\qhat + \bar{\q}`.
    """

    def __init__(
        self,
        entries,
        weights=None,
        check_orthogonality: bool = True,
        name: str = None,
    ):
        """Initialize the basis entries."""
        BasisTemplate.__init__(self, name=name)

        # Empty initializer for child classes (POD): the weights are stored
        # exactly as given and the dimensions are not set on this path.
        if entries is None:
            self.__entries = None
            self.__weights = weights
            return

        # Set the entries.
        self.__entries = entries
        BasisTemplate.full_state_dimension.fset(self, entries.shape[0])
        BasisTemplate.reduced_state_dimension.fset(self, entries.shape[1])

        # Set the weights. One-dimensional weights are the diagonal of W and
        # are stored as a sparse diagonal array instead of a dense (n, n) one.
        if weights is not None:
            n = entries.shape[0]
            if (dim := np.ndim(weights)) == 1:
                if np.size(weights) != n:
                    raise ValueError(
                        f"expected {n} diagonal weights, "
                        f"got {np.size(weights)}"
                    )
                weights = sparse.dia_array(([weights], [0]), shape=(n, n))
            elif dim != 2:
                raise ValueError("expected one- or two-dimensional weights")
            elif np.shape(weights) != (n, n):
                raise ValueError(
                    f"expected weights of shape {(n, n)}, "
                    f"got {np.shape(weights)}"
                )
        self.__weights = weights

        # Verify orthogonality (with respect to W, if present) if desired.
        if check_orthogonality:
            V, W = self.entries, self.weights
            Id = (V.T @ V) if W is None else (V.T @ W @ V)
            if not np.allclose(Id, np.eye(self.reduced_state_dimension)):
                warnings.warn("basis not orthogonal", errors.OpInfWarning)

    # Properties --------------------------------------------------------------
    @property
    def entries(self):
        r"""Entries of the basis matrix :math:`\Vr\in\RR^{n \times r}`.

        Also accessible via indexing (``basis[:]``).
        """
        return self.__entries

    @property
    def weights(self):
        r"""Weight matrix :math:`\W \in \RR^{n \times n}`, or ``None`` if
        :math:`\W` is the identity.

        One-dimensional weights given to the constructor are stored as a
        sparse diagonal array.
        """
        return self.__weights

    @property
    def full_state_dimension(self):
        r"""Dimension :math:`n` of the full state."""
        return BasisTemplate.full_state_dimension.fget(self)

    @property
    def reduced_state_dimension(self):
        r"""Dimension :math:`r` of the reduced (compressed) state."""
        return BasisTemplate.reduced_state_dimension.fget(self)

    @requires_entries
    def __getitem__(self, key):
        """self[:] --> self.entries."""
        return self.entries[key]

    def fit(self, *args, **kwargs):
        """Do nothing, the basis entries are set in the constructor."""
        return self

    # Dimension reduction -----------------------------------------------------
    @requires_entries
    def compress(self, state: np.ndarray) -> np.ndarray:
        r"""Map high-dimensional states to low-dimensional latent coordinates.

        .. math:: \q \mapsto \qhat = \Vr\trp\q.

        If a weight matrix :math:`\W` is present, the compression is

        .. math:: \q \mapsto \qhat = \Vr\trp\W\q.

        Parameters
        ----------
        state : (n, ...) ndarray
            Matrix of `n`-dimensional state vectors, or a single state vector.

        Returns
        -------
        states_compressed : (r, ...) ndarray
            Matrix of `r`-dimensional latent coordinate vectors, or a single
            coordinate vector.
        """
        if self.weights is not None:
            state = self.weights @ state
        return self.entries.T @ state

    @requires_entries
    def decompress(
        self,
        states_compressed: np.ndarray,
        locs=None,
    ) -> np.ndarray:
        r"""Map low-dimensional latent coordinates to high-dimensional states.

        .. math:: \qhat \mapsto \breve{\q} = \Vr\qhat

        Parameters
        ----------
        states_compressed : (r, ...) ndarray
            Matrix of `r`-dimensional latent coordinate vectors, or a single
            coordinate vector.
        locs : slice or (p,) ndarray of integers or None
            If given, return the decompressed state at only the `p` specified
            locations (indices) described by ``locs``.

        Returns
        -------
        states_decompressed : (n, ...) or (p, ...) ndarray
            Matrix of `n`-dimensional decompressed state vectors, or the `p`
            entries of such at the entries specified by ``locs``.
        """
        # Restricting the rows of Vr first avoids building the full state.
        Vr = self.entries if locs is None else self.entries[locs]
        return Vr @ states_compressed

    # Visualizations ----------------------------------------------------------
    @requires_entries
    def plot1D(self, x=None, num_vectors=None, ax=None, **kwargs):
        """Plot the basis vectors over a one-dimensional domain.

        Parameters
        ----------
        x : (n,) ndarray or None
            One-dimensional spatial domain over which to plot the vectors.
            Defaults to [0, 1] with `n` points.
        num_vectors : int or None
            Number of basis vectors to plot. If ``None`` (default), plot all
            basis vectors. Values larger than `r` are clipped to `r`.
        ax : plt.Axes or None
            Matplotlib Axes to plot on.
            If ``None`` (default), a new figure is created.
        kwargs : dict
            Other keyword arguments to pass to ``plt.plot()``.

        Returns
        -------
        ax : plt.Axes
            Matplotlib Axes for the plot.
        """
        if x is None:
            x = np.linspace(0, 1, self.full_state_dimension)
        if num_vectors is None:
            num_vectors = self.reduced_state_dimension
        num_vectors = min(num_vectors, self.reduced_state_dimension)
        if ax is None:
            ax = plt.figure().add_subplot(111)

        for j in range(num_vectors):
            ax.plot(x, self.entries[:, j], **kwargs)
        ax.set_xlim(x[0], x[-1])
        ax.set_xlabel("spatial domain")
        ax.set_ylabel("basis vectors")

        return ax

    # Persistence -------------------------------------------------------------
    @staticmethod
    def _sparse_diagonal(w):
        """Return the main diagonal of ``w`` if ``w`` is a sparse DIA-format
        array/matrix that stores only its main diagonal; otherwise ``None``.
        """
        if (
            sparse.issparse(w)
            and w.format == "dia"
            and np.all(np.asarray(w.offsets) == 0)
        ):
            return w.diagonal()
        return None

    @classmethod
    def _weights_equal(cls, w1, w2) -> bool:
        """Return ``True`` if two weight specifications agree (``np.allclose``).

        ``None`` only equals ``None``. Diagonal-only sparse weights are
        compared through their diagonals; any other sparse weights are
        densified so sparse and dense representations compare correctly.
        """
        if w1 is None or w2 is None:
            return w1 is None and w2 is None
        if np.shape(w1) != np.shape(w2):
            return False
        d1, d2 = cls._sparse_diagonal(w1), cls._sparse_diagonal(w2)
        if d1 is not None and d2 is not None:
            return bool(np.allclose(d1, d2))
        # Comparing only `.data[0]` would check a single stored value for a
        # general sparse matrix, so compare full (dense) matrices instead.
        if sparse.issparse(w1):
            w1 = w1.toarray()
        if sparse.issparse(w2):
            w2 = w2.toarray()
        return bool(np.allclose(w1, w2))

    def __eq__(self, other) -> bool:
        """Two LinearBasis objects are equal if their type, dimensions,
        weights (up to ``np.allclose``), and basis entries are the same.
        """
        if not isinstance(other, self.__class__):
            return False
        if self.shape != other.shape:
            return False
        if not self._weights_equal(self.weights, other.weights):
            return False
        return bool(np.all(self.entries == other.entries))

    def save(self, savefile: str, overwrite: bool = False):
        """Save the basis to an HDF5 file.

        Diagonal-only sparse weights are saved as their (n,) diagonal; other
        sparse weights are saved as a dense (n, n) array.

        Parameters
        ----------
        savefile : str
            Path of the file to save the basis to.
        overwrite : bool
            If ``True``, overwrite the file if it already exists. If ``False``
            (default), raise a ``FileExistsError`` if the file already exists.
        """
        with utils.hdf5_savehandle(savefile, overwrite) as hf:
            if self.name:
                meta = hf.create_dataset("meta", shape=(0,))
                meta.attrs["name"] = self.name
            hf.create_dataset("entries", data=self.entries)
            if (w := self.weights) is not None:
                if (diagonal := self._sparse_diagonal(w)) is not None:
                    w = diagonal
                elif sparse.issparse(w):
                    # HDF5 datasets cannot hold scipy sparse objects.
                    w = w.toarray()
                hf.create_dataset("weights", data=w)

    @classmethod
    def load(cls, loadfile: str):
        """Load a basis from an HDF5 file.

        Parameters
        ----------
        loadfile : str
            Path to the file where the basis was stored via :meth:`save`.

        Returns
        -------
        LinearBasis
        """
        with utils.hdf5_loadhandle(loadfile) as hf:
            name = None
            if "meta" in hf:
                name = hf["meta"].attrs["name"]
            entries = hf["entries"][:]
            weights = hf["weights"][:] if "weights" in hf else None
            return cls(entries, weights, name=name)