Source code for mutopia.modalities.mode_config

from __future__ import annotations

import os
import json
from abc import ABC, abstractmethod
from typing import Any, Dict, Mapping, Sequence, Tuple, List, Optional

import xarray as xr
from mutopia.palettes import categorical_palette


[docs] class ModeConfig(ABC): """Abstract base class for modality configuration. A ``ModeConfig`` defines the coordinate system, default palettes and labels, and a small set of utilities used by modality-specific models: - dimension/coordinate descriptions (``coords``, ``sizes``, ``dims``) - access to the modality-specific TopographyModel (``TopographyModel``) - helpers to load reference components and compute context frequencies - validation and plotting helpers for signatures/observations Subclasses must implement ``coords``, ``TopographyModel`` and the data pipeline methods ``get_context_frequencies`` and ``ingest_observations``. Attributes ---------- MODE_ID : str Stable modality identifier string (e.g., "sbs"). PALETTE : list Default palette for plotting signatures for this modality. X_LABELS : list[str] Optional x-axis labels to use when plotting (defaults to context labels). DATABASE : str Relative path to a JSON file with named reference components. """ MODE_ID: Optional[str] = None PALETTE: Optional[list] = None X_LABELS: Optional[list[str]] = None DATABASE: Optional[str] = None @property def sizes(self) -> Dict[str, int]: """Sizes of each coordinate dimension for this modality. Returns ------- dict[str, int] Mapping from dimension key to number of labels. """ return {k: len(v[1]) for k, v in self.coords.items()} @property def dims(self) -> Tuple[str, ...]: """Tuple of dimension keys in the canonical order. Returns ------- tuple[str, ...] Dimension keys corresponding to ``coords`` order. """ return tuple(self.coords.keys()) @property @abstractmethod def coords(self) -> Dict[str, Tuple[str, list]]: """Coordinate spec for this modality. Returns ------- dict[str, tuple[str, list]] Mapping from key to a pair of (dimension name, labels). """ raise NotImplementedError @property @abstractmethod def TopographyModel(self): """Modality-specific TopographyModel class. Implementations may import lazily to avoid cycles. """ raise NotImplementedError @property def available_components(self) -> List[str]: """List available reference components in ``DATABASE``. Returns ------- list[str] Names of components available for ``load_components``. """ db_path = os.path.join(os.path.dirname(__file__), self.DATABASE) with open(db_path, "r") as f: database = json.load(f) return list(database.keys())
[docs] def load_components(self, *init_components: str) -> xr.DataArray: """Load named reference components from the modality database. Parameters ---------- *init_components : str Component names to load (must exist in ``available_components``). Returns ------- xarray.DataArray Array of shape (component, context) with component spectra and modality metadata attached via attributes. """ import numpy as np db_path = os.path.join(os.path.dirname(__file__), self.DATABASE) with open(db_path, "r") as f: database = json.load(f) comps = [] for component in init_components: if not component in database: raise ValueError(f"Component {component} not found in database") comps.append([database[component][l] for l in self.coords["context"][1]]) return xr.DataArray( np.array(comps, dtype=float), dims=("component", "context"), coords={ "context": ("context", self.coords["context"][1]), "component": ("component", list(init_components)), }, attrs={ "dtype": self.MODE_ID, }, )
[docs] @classmethod @abstractmethod def get_context_frequencies( cls, *, regions_file, fasta_file, ) -> xr.DataArray: """Compute per-locus context frequencies for this modality. Implementations must return an array with (configuration, context, locus) or analogous dims appropriate for the modality, using dense or sparse storage as needed. """ raise NotImplementedError
[docs] @classmethod @abstractmethod def ingest_observations( cls, *, input_file, regions_file, fasta_file, **kwargs, ) -> xr.DataArray: """Ingest raw input into a canonical observation tensor. Implementations should return an ``xarray.DataArray`` with dims (configuration, context, mutation, locus) using dense or sparse storage. The array should include a ``name`` attribute identifying the sample. """ raise NotImplementedError
def _arr_to_xr( self, dim_sizes: Mapping[str, int], coords: Any, data: Any, ) -> xr.DataArray: """Helper to build a sparse DataArray from COO inputs. Parameters ---------- dim_sizes : Mapping[str, int] Sizes for each modality dimension (excluding locus). coords : Any COO coordinates array or sequence. data : Any Values aligned with ``coords``. Returns ------- xarray.DataArray Sparse COO-backed DataArray with dims (*sizes, locus). """ from sparse import SparseArray, COO return xr.DataArray( COO( coords, data, shape=(*self.sizes.values(), dim_sizes["locus"]), ), dims=(*self.sizes.keys(), "locus"), ) @classmethod def _flatten_observations(cls, signature: xr.DataArray) -> xr.DataArray: """Normalize signature tensor shape and ordering. Ensures the ``context`` dimension is the trailing dimension and validates minimum required dims. """ cls.validate_signatures( signature, required_dims=("context",), ) signature = signature.transpose(..., "context") return signature
[docs] @classmethod def validate_signatures( cls, signature: xr.DataArray, required_dims: Sequence[str] = (), ) -> None: """Validate signature tensor structure for plotting or analysis. Parameters ---------- signature : xarray.DataArray Signature tensor. required_dims : Sequence[str] Required dimension names. """ from sparse import SparseArray for dim in required_dims: if not dim in signature.dims: raise ValueError( f"Expected dimension {dim} in signature but got {signature.dims}" ) if not len(signature.dims) <= (len(required_dims) + 1): raise ValueError( f"Expected signature to have at most {len(required_dims) + 1} dimensions " f"({', '.join(required_dims)}, +1 other), but got {', '.join(signature.dims)}" ) if isinstance(signature.data, SparseArray): signature.data = signature.data.todense()
@classmethod def _parse_signatures( cls, signature: xr.DataArray, select: Sequence[str] = (), sig_names: Optional[Sequence[str]] = None, ) -> Tuple[List[Any], Sequence[str], Optional[str]]: """Parse a signature tensor into a list of 1D signatures. Returns signatures, their names, and the leading dimension name (if any). """ if len(signature.dims) > 2: raise ValueError("`signature` must have at most 2 dimensions.") has_extra_dims = len(signature.dims) > 1 if len(select) > 0 and not has_extra_dims: raise ValueError( "`select` is only valid if the signature has one extra dimension to choose from." ) if len(select) > 0: lead_dim = signature.dims[0] select = [ feature_name for feature_name in signature[lead_dim].values if any( feature_name == q or feature_name.rsplit(":", 1)[0] == q for q in select ) ] if len(select) == 0: raise ValueError("No signatures selected. Check the `select` argument.") pl_signatures = list(signature.sel(genome_state=select).data) if sig_names and not len(select) == len(sig_names): raise ValueError( "If both `sig_names` and `select` are specified, they must have the same length." ) sig_names = sig_names or select return pl_signatures, sig_names, lead_dim elif not has_extra_dims: pl_signatures = [signature.data.ravel()] return pl_signatures, ["Signature"], None else: lead_dim = signature.dims[0] pl_signatures = list(signature.data) sig_names = signature.coords[lead_dim].values return pl_signatures, sig_names, lead_dim
[docs] @classmethod def plot( cls, signature: xr.DataArray, *select: str, palette=categorical_palette, sig_names: Optional[Sequence[str]] = None, normalize: bool = False, title: Optional[str] = None, width: float = 5.25, height: float = 1.25, ax: Any = None, label_xaxis: bool = True, **kwargs: Any, ) -> Any: """Plot one or more modality signatures as a linear bar plot. Parameters ---------- signature : xarray.DataArray Signature tensor with a trailing ``context`` dimension. May have an optional leading dimension to represent multiple signatures. *select : str Optional names to select a subset of signatures from the leading dimension. Matching is exact or by ``name:prefix`` before the last colon. palette : sequence, optional Color palette for plotting multiple signatures; defaults to the modality palette when a single signature is plotted. sig_names : Sequence[str], optional Custom names for the selected signatures; must match selection size. normalize : bool, default False If True, normalize each signature to sum to 1 before plotting. title : str, optional Axes title. width, height : float, default (5.25, 1.25) Figure sizing passed to the underlying plotting helper. ax : matplotlib.axes.Axes, optional Existing axes to draw on; if None, a new figure/axes is created. label_xaxis : bool, default True Whether to show x-axis tick labels. **kwargs Additional keyword arguments forwarded to the plotting helper. Returns ------- matplotlib.axes.Axes The axes containing the rendered plot. """ from mutopia.plot.signature_plot import _plot_linear_signature signature = cls._flatten_observations(signature) pl_signatures, sig_names, legend_title = cls._parse_signatures( signature, select=select, sig_names=sig_names, ) if len(pl_signatures) > 100: raise ValueError( f"Cannot plot more than 100 signatures at once. " f"Choose a subset of signatures to plot using the `select` argument." ) if normalize: pl_signatures = [s / s.sum() for s in pl_signatures] ax = _plot_linear_signature( ( signature.coords["context"].values if cls.X_LABELS is None else cls.X_LABELS ), palette if len(pl_signatures) > 1 else cls.PALETTE, plot_kw=kwargs, legend_title=legend_title, height=height, width=width, ax=ax, label_xaxis=label_xaxis, **dict(zip(sig_names, pl_signatures)), ) if title: ax.set_title(title) return ax
[docs] @classmethod def validate_observations( cls, locus_dim: int, observations: xr.DataArray, ) -> None: """Validate the observation tensor shape and required metadata. Ensures dims match (configuration, context, mutation, locus) and that a ``name`` attribute is present. """ if not observations.shape == (*cls.dims, locus_dim): raise ValueError( f"Expected tensor of shape {(*cls.dims, locus_dim)} but got {observations.shape}" ) if not "name" in observations.attrs: raise ValueError("Expected 'name' attribute in observations") if not observations.dims == ("configuration", "context", "mutation", "locus"): raise ValueError( "Expected dimensions ('configuration', 'context', 'mutation', 'locus')" )