Source code for mutopia.gtensor.gtensor

"""
The GTensor module for genomic tensor analysis.
This module provides functionality for creating, manipulating, and analyzing genomic tensors,
including loading datasets, applying transformations, and generating explanations for model components.

GTensors are hierarchical, multi-dimensional arrays designed to represent complex genomic data structures.
They are sliceable along multiple dimensions, and support lazy loading for memory efficiency.

Use the GTensor CLI tool to interact with and build GTensor datasets from the command line - the
Python API is mostly intended for analysis and visualization.
"""

from __future__ import annotations

import xarray as xr
#xr.set_options(use_new_combine_kwarg_defaults=True)
import pandas as pd
import numpy as np
from typing import Union, List, Any, Callable, Optional, Iterable, TYPE_CHECKING
from numpy.typing import NDArray
from functools import reduce
import os
from tqdm import tqdm
from mutopia.utils import logger, parse_region
from mutopia.genome_utils.bed12_utils import unstack_regions as _unstack_regions
import mutopia.gtensor.disk_interface as disk
from .interfaces import (
    CorpusInterface,
    LazySampleLoader,
    LocusSlice,
    SampleSlice,
)
if TYPE_CHECKING:
    import shap

# Union of every object the functions in this module accept as a "dataset":
# a plain xarray Dataset or one of the interface wrappers imported from
# ``.interfaces``.
GTensorDataset = Union[
    xr.Dataset,
    CorpusInterface,
    LazySampleLoader,
    LocusSlice,
    SampleSlice,
]

# Public names exported by a star-import of this module.
# NOTE(review): some entries (e.g. "list_components", "fetch_component",
# "fetch_interactions", "fetch_shared_effects", "excel_report") have no
# definition in the part of the module reviewed here - confirm they are
# defined elsewhere in the file.
__all__ = [
    "GTensor",
    "apply_to_samples",
    "fetch_features",
    "load_dataset",
    "train_test_split",
    "lazy_load",
    "eager_load",
    "lazy_train_test_load",
    "eager_train_test_load",
    "num_sources",
    "is_mixture_dataset",
    "list_sources",
    "fetch_source",
    "get_explanation",
    "get_shap_summary",
    "equal_size_quantiles",
    "slice_regions",
    "slice_samples",
    "annot_empirical_marginal",
    "make_mixture_dataset",
    "match_dims",
    "dims_except_for",
    "unstack_regions",
    "mutate_method",
    "BED_COLS",
    "list_components",
    "fetch_component",
    "fetch_interactions",
    "fetch_shared_effects",
    "rename_components",
    "excel_report",
    "infer_source_celltypes",
]

# Names of the three region-coordinate variables (chromosome, start, end);
# they match the "Regions/..." variables created by ``GTensor``.
BED_COLS = [
    "Regions/chrom",
    "Regions/start",
    "Regions/end",
]

def GTensor(
    modality: Any,
    *,
    name: str,
    chrom: List[str],
    start: List[int],
    end: List[int],
    context_frequencies: xr.DataArray,
    exposures: Union[None, NDArray[np.number]] = None,
    dtype: Any = None,
) -> GTensorDataset:
    """
    Create a GTensor dataset for genomic tensor analysis.

    Builds an ``xarray.Dataset`` holding the region coordinates, context
    frequencies, region lengths and exposures, with the modality's
    coordinates plus a ``locus`` coordinate and an empty ``sample``
    coordinate.

    Parameters
    ----------
    modality : object
        Object providing ``coords`` (merged into the dataset coordinates)
        and ``MODE_ID`` (used as the default ``dtype`` attribute).
    name : str
        Stored in the dataset's ``name`` attribute.
    chrom : List[str]
        Chromosome of each region; its length defines the number of loci.
    start : List[int]
        Start position of each region.
    end : List[int]
        End position of each region.
    context_frequencies : xr.DataArray
        Context frequencies per region. Region lengths are obtained by
        summing over every axis except the last, so the last axis is
        treated as the locus axis.
    exposures : array-like, optional
        Exposure per region. Any shape holding one value per locus is
        accepted (it is flattened). Defaults to ones.
    dtype : optional
        Stored in the dataset's ``dtype`` attribute. Falls back to
        ``modality.MODE_ID`` when falsy.

    Returns
    -------
    xr.Dataset
        Dataset with the ``Regions/*`` variables, coordinates and attributes.
    """
    n_loci = len(chrom)
    locus_coords = pd.Index(np.arange(n_loci))

    shared_coords = {
        **modality.coords,
        "locus": locus_coords,
        "sample": [],
    }

    # Collapse every leading axis; what remains is one length per locus.
    leading_axes = tuple(range(context_frequencies.data.ndim - 1))
    region_lengths = np.sum(context_frequencies.data, axis=leading_axes)

    if exposures is None:
        exposures = np.ones(n_loci, dtype=np.float32)
    # Flatten rather than squeeze: squeezing a single-locus vector yields a
    # 0-d array, which cannot be labelled with the "locus" dimension.
    exposures = np.asarray(exposures).reshape(-1)

    return xr.Dataset(
        {
            "Regions/context_frequencies": context_frequencies,
            "Regions/length": xr.DataArray(np.array(region_lengths), dims=("locus",)),
            "Regions/exposures": xr.DataArray(exposures, dims=("locus",)),
            "Regions/chrom": xr.DataArray(np.array(chrom), dims=("locus",)),
            "Regions/start": xr.DataArray(np.array(start), dims=("locus",)),
            "Regions/end": xr.DataArray(np.array(end), dims=("locus",)),
        },
        coords=shared_coords,
        attrs={
            "name": name,
            "dtype": dtype or modality.MODE_ID,
        },
    )
def infer_source_celltypes(dataset: GTensorDataset) -> GTensorDataset:
    """
    Infer source cell types for a dataset.

    Thin public wrapper: the dataset is passed unchanged to
    ``disk.infer_source_celltypes`` (``mutopia.gtensor.disk_interface``) and
    its result is returned as-is.

    Parameters
    ----------
    dataset : GTensorDataset
        Dataset whose sources should be inferred.

    Returns
    -------
    GTensorDataset
        Whatever the disk-interface helper returns.
        NOTE(review): presumably the dataset with a 'source' coordinate
        derived from the feature names - confirm against the helper.
    """
    annotated = disk.infer_source_celltypes(dataset)
    return annotated
def apply_to_samples(data: GTensorDataset, func: Callable, bar: bool = True) -> GTensorDataset:
    """
    Apply a function to every sample and concatenate the results.

    Samples are fetched and processed one after another (sequentially); the
    per-sample results are joined with ``xr.concat`` along a ``sample``
    dimension.

    Parameters
    ----------
    data : GTensorDataset
        Dataset or sample loader. Objects without an ``X`` attribute are
        wrapped in a ``LazySampleLoader`` first.
    func : callable
        Called with the object returned by ``data.fetch_sample(name)``.
        Must return something ``xr.concat`` can concatenate.
    bar : bool, default=True
        Whether to show a tqdm progress bar while iterating.

    Returns
    -------
    GTensorDataset
        Concatenation of all per-sample results along ``sample``.
    """
    if not hasattr(data, "X"):
        data = LazySampleLoader(data)

    sample_names = data.list_samples()
    if bar:
        sample_names = tqdm(sample_names, desc="Applying function to samples")

    per_sample_results = []
    for sample_name in sample_names:
        sample = data.fetch_sample(sample_name)
        per_sample_results.append(func(sample))

    return xr.concat(per_sample_results, dim="sample")
def mutate(func: Callable) -> Callable:
    """
    Decorator that routes a dataset transformation through ``dataset.mutate``.

    The wrapped callable is invoked as ``wrapper(dataset, *args, **kwargs)``
    and evaluates ``dataset.mutate(lambda x: func(x, *args, **kwargs))``, so
    the transformation runs inside the dataset's own ``mutate`` hook instead
    of being applied to the interface object directly.

    Parameters
    ----------
    func : callable
        Function taking the underlying dataset as its first argument and
        returning the modified dataset.

    Returns
    -------
    callable
        Wrapper with the same name, docstring and metadata as ``func``.
    """
    # Local import keeps this block self-contained.
    from functools import wraps

    @wraps(func)  # preserve __name__/__doc__ of the decorated function
    def wrapper(dataset, *args, **kwargs):
        return dataset.mutate(lambda x: func(x, *args, **kwargs))

    return wrapper
def mutate_method(func: Callable) -> Callable:
    """
    Method variant of the mutate decorator.

    The wrapped method is invoked as ``wrapper(self, dataset, *args,
    **kwargs)`` and evaluates
    ``dataset.mutate(lambda x: func(self, x, *args, **kwargs))``.

    Parameters
    ----------
    func : callable
        Method taking ``self`` and the underlying dataset as its first two
        arguments and returning the modified dataset.

    Returns
    -------
    callable
        Wrapper with the same name, docstring and metadata as ``func``.
    """
    # Local import keeps this block self-contained.
    from functools import wraps

    @wraps(func)  # preserve __name__/__doc__ of the decorated method
    def wrapper(self, dataset, *args, **kwargs):
        return dataset.mutate(lambda x: func(self, x, *args, **kwargs))

    return wrapper
def load_dataset(
    dataset: Union[str, os.PathLike],
    with_samples: bool = True,
    with_state: bool = True,
) -> GTensorDataset:
    """
    Load a dataset from disk and wrap it in the matching interface.

    Parameters
    ----------
    dataset : str or path-like
        Location of the dataset; forwarded to ``disk.load_dataset``.
    with_samples : bool, default=True
        Forwarded to ``disk.load_dataset``; also selects the wrapper type.
    with_state : bool, default=True
        Forwarded to ``disk.load_dataset``.

    Returns
    -------
    GTensorDataset
        ``CorpusInterface`` when ``with_samples`` is true, otherwise
        ``LazySampleLoader``.
    """
    raw = disk.load_dataset(dataset, with_samples=with_samples, with_state=with_state)
    if with_samples:
        return CorpusInterface(raw)
    return LazySampleLoader(raw)
def train_test_split(
    dataset: GTensorDataset, *test_chroms: Union[str, List[str]], lazy: bool = False
) -> tuple[GTensorDataset, GTensorDataset]:
    """
    Split a dataset into training and testing sets by chromosome.

    Loci whose chromosome is listed in ``test_chroms`` form the test set;
    all other loci form the training set.

    Parameters
    ----------
    dataset : GTensorDataset
        Dataset to split.
    *test_chroms : Union[str, List[str]]
        Chromosomes reserved for the test set. Each argument may be a single
        name or a list/tuple of names; everything is flattened into one list.
    lazy : bool, default=False
        Return ``LocusSlice`` views instead of materialised subsets. Forced
        to true when the dataset has no ``X`` data variable.

    Returns
    -------
    tuple[GTensorDataset, GTensorDataset]
        ``(train, test)``: two ``LocusSlice`` objects in lazy mode, two
        ``CorpusInterface`` objects otherwise.

    Raises
    ------
    ValueError
        If no chromosome is given, or none of them occurs in the dataset.
    """
    # Flatten so that both f(ds, "chr1", "chr2") and f(ds, ["chr1", "chr2"])
    # work, as the signature and documentation promise.
    chroms = []
    for entry in test_chroms:
        if isinstance(entry, (str, bytes)) or not isinstance(entry, Iterable):
            chroms.append(entry)
        else:
            chroms.extend(entry)

    if not chroms:
        raise ValueError("No test chromosomes provided.")

    test_mask = dataset.sections["Regions"].chrom.isin(chroms)
    if test_mask.sum() == 0:
        raise ValueError(
            f'None of the chromosomes in {",".join(map(str, chroms))} are present in the dataset. '
        )

    lazy = lazy or "X" not in dataset.data_vars

    if lazy:
        logger.warning(
            "The dataset is lazy, so the train/test split will be lazy as well. "
            "This may cause latency issues on systems with slow file IO."
        )
        train = LocusSlice(dataset, locus=~test_mask)
        test = LocusSlice(dataset, locus=test_mask)

        drop_vars = (
            dataset.sections.groups["Features"] + dataset.sections.groups["Regions"]
        )
        # Only the train slice's base corpus is replaced; the analogous
        # statement for the test slice was left disabled in the original.
        # NOTE(review): confirm whether both slices share one base corpus.
        train._base_corpus.corpus = train._base_corpus.drop_vars(drop_vars)
        return train, test

    train = CorpusInterface(dataset.isel(locus=~test_mask))
    test = CorpusInterface(dataset.isel(locus=test_mask))
    return train, test
def lazy_load(dataset: Union[str, os.PathLike]) -> GTensorDataset:
    """
    Load a dataset without samples and without state.

    Shorthand for ``load_dataset(dataset, with_samples=False,
    with_state=False)``.

    Parameters
    ----------
    dataset : str or path-like
        Location of the dataset to load.

    Returns
    -------
    GTensorDataset
        The object returned by ``load_dataset`` for these options.
    """
    load_options = {"with_samples": False, "with_state": False}
    return load_dataset(dataset, **load_options)
def eager_load(dataset: Union[str, os.PathLike]) -> GTensorDataset:
    """
    Load a dataset with samples but without state.

    Shorthand for ``load_dataset(dataset, with_samples=True,
    with_state=False)``.

    Parameters
    ----------
    dataset : str or path-like
        Location of the dataset to load.

    Returns
    -------
    GTensorDataset
        The object returned by ``load_dataset`` for these options.
    """
    load_options = {"with_samples": True, "with_state": False}
    return load_dataset(dataset, **load_options)
def lazy_train_test_load(
    dataset: Union[str, os.PathLike], *test_chroms: str
) -> tuple[GTensorDataset, GTensorDataset]:
    """
    Lazily load a dataset and split it by chromosome.

    Combines ``lazy_load`` with ``train_test_split(..., lazy=True)``.

    Parameters
    ----------
    dataset : str or path-like
        Location of the dataset to load.
    *test_chroms : str
        Chromosomes reserved for the test set.

    Returns
    -------
    tuple[GTensorDataset, GTensorDataset]
        ``(train, test)`` as returned by ``train_test_split`` in lazy mode.
    """
    corpus = lazy_load(dataset)
    return train_test_split(corpus, *test_chroms, lazy=True)
def eager_train_test_load(
    dataset: Union[str, os.PathLike], *test_chroms: str
) -> tuple[GTensorDataset, GTensorDataset]:
    """
    Eagerly load a dataset and split it by chromosome.

    Combines ``eager_load`` with ``train_test_split(..., lazy=False)``.

    Parameters
    ----------
    dataset : str or path-like
        Location of the dataset to load.
    *test_chroms : str
        Chromosomes reserved for the test set.

    Returns
    -------
    tuple[GTensorDataset, GTensorDataset]
        ``(train, test)`` as returned by ``train_test_split``.
    """
    corpus = eager_load(dataset)
    return train_test_split(corpus, *test_chroms, lazy=False)
def num_sources(dataset: GTensorDataset) -> int:
    """
    Count the sources of a dataset.

    Parameters
    ----------
    dataset : GTensorDataset
        Dataset to inspect.

    Returns
    -------
    int
        Length of ``list_sources(dataset)``; 0 when the dataset has no
        'source' coordinate.
    """
    sources = list_sources(dataset)
    return len(sources)
def is_mixture_dataset(dataset: GTensorDataset) -> bool:
    """
    Tell whether a dataset holds more than one source.

    Parameters
    ----------
    dataset : GTensorDataset
        Dataset to inspect.

    Returns
    -------
    bool
        True when ``num_sources(dataset)`` is greater than one.
    """
    source_count = num_sources(dataset)
    return source_count > 1
def list_sources(dataset: GTensorDataset) -> List[str]:
    """
    Return the values of the dataset's 'source' coordinate.

    Parameters
    ----------
    dataset : GTensorDataset
        Dataset to inspect.

    Returns
    -------
    List[str]
        The 'source' coordinate values as a plain list, or an empty list
        when the dataset has no 'source' coordinate.
    """
    coords = dataset.coords
    if "source" not in coords:
        return []
    source_values = coords["source"].values
    return source_values.tolist()
def fetch_source(dataset: GTensorDataset, source: str) -> GTensorDataset:
    """
    Extract the data of one source from a multi-source dataset.

    Variables named ``Features/<source>/<name>`` and ``State/<source>/<name>``
    are kept and renamed to ``Features/<name>`` / ``State/<name>``. Variables
    directly under ``Features/`` or ``State/`` (two path components) are
    treated as shared and kept as they are, as are the variables of every
    other group.

    Parameters
    ----------
    dataset : GTensorDataset
        Dataset holding several sources.
    source : str
        Source to extract; must be one of ``list_sources(dataset)``.

    Returns
    -------
    GTensorDataset
        Dataset restricted to ``source``: its ``name`` attribute is
        ``"<original name>/<source>"``, the coordinates of ``dataset`` are
        re-attached, and the ``source`` dimension is dropped.

    Raises
    ------
    ValueError
        If ``source`` is not present in the dataset.

    Notes
    -----
    Variable names are always joined with "/" (never ``os.path.join``), so
    the lookup also works on platforms whose path separator is not "/".
    """
    if source not in list_sources(dataset):
        raise ValueError(f"Source {source} not found in dataset")

    # Copy before popping so the mapping owned by the dataset is never altered.
    groups = dict(dataset.sections.groups)
    state = groups.pop("State", [])
    features = groups.pop("Features", [])

    def _source_renames(level, var_names):
        # Map "<level>/<source>/<name>" -> "<level>/<name>" for this source.
        renames = {}
        for var_name in var_names:
            parts = var_name.split("/")
            # Require a nested path so a shared variable that merely shares
            # its name with the source is not mistaken for a source variable.
            if len(parts) > 2 and parts[1] == source:
                base = os.path.basename(var_name)
                renames[f"{level}/{source}/{base}"] = f"{level}/{base}"
        return renames

    rename_map = {
        **_source_renames("Features", features),
        **_source_renames("State", state),
    }

    other_dvars = [v for group_vars in groups.values() for v in group_vars]
    shared_features = [v for v in features if len(v.split("/")) == 2]
    shared_state = [v for v in state if len(v.split("/")) == 2]

    source_corpus = dataset[
        list(rename_map) + other_dvars + shared_features + shared_state
    ].rename(rename_map)

    source_corpus.attrs["name"] = dataset.attrs["name"] + "/" + source

    if "source" in source_corpus.dims:
        source_corpus = source_corpus.sel(source=source, drop=True)

    # Re-attach the original coordinates, then remove the source dimension
    # that comes back with them.
    source_corpus = source_corpus.assign_coords(**dataset.coords).drop_dims("source")
    return source_corpus
def get_explanation(dataset: GTensorDataset, component: str) -> "shap.Explanation":
    """
    Build a ``shap.Explanation`` for one model component.

    Parameters
    ----------
    dataset : GTensorDataset
        Dataset with a ``SHAP_values`` variable, ``State/locus_features``
        and a ``Features`` section. For mixture datasets every source is
        processed separately and the results are concatenated.
    component : str
        Component to explain; must occur in the ``shap_component``
        coordinate of ``SHAP_values``.

    Returns
    -------
    shap.Explanation
        Explanation built from the SHAP values (loci x features), the
        matching ``State/locus_features`` values as ``data`` and the raw
        feature values as ``display_data``.

    Raises
    ------
    ImportError
        If the ``shap`` package cannot be imported.
    ValueError
        If there are no SHAP values for ``component``.
    """
    try:
        import shap
    except ImportError as err:
        raise ImportError("SHAP is required to calculate SHAP values") from err

    if component not in dataset["SHAP_values"].shap_component.values:
        raise ValueError(
            f"The dataset does not have SHAP values for component {component}."
        )

    def _get_shap_from_source(dataset, component):
        # Returns (SHAP table, model feature values, display feature values),
        # each indexed by locus with one column per feature.
        shap_values = dataset["SHAP_values"]
        locus_dim = "locus" if "locus" in shap_values.dims else "shap_locus"

        shap_df = (
            shap_values.sel(shap_component=component)
            .to_dataframe()
            .reset_index()
            .rename(
                columns={
                    "shap_component": "component",
                    locus_dim: "locus",
                    "shap_features": "feature",
                    "SHAP_values": "value",
                }
            )
        )

        # Features named "<feature>:<convolution>" are reduced to the base
        # feature; their SHAP values are summed by the groupby below.
        if any(shap_df.feature.str.contains(":")):
            shap_df[["feature", "convolution"]] = shap_df.feature.str.split(
                ":", expand=True, n=1
            ).rename(columns={0: "feature", 1: "convolution"})

        shap_df = (
            shap_df.groupby(["feature", "locus"])["value"].sum().unstack().fillna(0).T
        )

        data = (
            dataset["State/locus_features"]
            .sel(locus=shap_df.index)
            .sel(
                feature=[
                    f"{s}:0" if f"{s}:0" in dataset.coords["feature"].values else s
                    for s in shap_df.columns
                ]
            )
        ).to_pandas()

        display_features = (
            dataset.sections["Features"]
            .assign_coords(locus=dataset.locus.data)
            .sel(locus=shap_df.index)
        )
        display_data = pd.DataFrame(
            [display_features[s].data for s in shap_df.columns],
            index=shap_df.columns,
        ).T

        return (
            shap_df,
            data,
            display_data,
        )

    if is_mixture_dataset(dataset):
        if "ploidy" in dataset.data_vars:
            dataset = dataset.drop_vars("ploidy")
        shap_data = [
            _get_shap_from_source(fetch_source(dataset, source_name), component)
            for source_name in list_sources(dataset)
        ]
    else:
        shap_data = [_get_shap_from_source(dataset, component)]

    # Stack the per-source tables row-wise, one concat per tuple position.
    shap_df, data, display_data = [pd.concat(x) for x in zip(*shap_data)]

    expl = shap.Explanation(
        shap_df.values,
        feature_names=shap_df.columns,
        data=data.values,
        display_data=display_data,
    )
    return expl
def get_shap_summary(data: GTensorDataset, source: Optional[str] = None) -> pd.DataFrame:
    """
    Summarise SHAP values per component and feature.

    Parameters
    ----------
    data : GTensorDataset
        Dataset with a ``SHAP_values`` variable and a
        ``State[/<source>]/locus_features`` variable.
    source : str, optional
        Source whose ``State/<source>/locus_features`` should be used.
        Required for mixture datasets.

    Returns
    -------
    pd.DataFrame
        One row per (component, feature) with columns:
        - component
        - feature
        - effect_size: 0.97 quantile of the absolute SHAP values
        - correlation: correlation coefficient between SHAP values and
          feature values, NaNs ignored (NaN if fewer than two valid pairs)

    Raises
    ------
    ValueError
        If the dataset is a mixture dataset and no source is specified.

    Notes
    -----
    When feature names have the form ``<feature>:<window>``, rows are
    aggregated per base feature: SHAP values are summed and feature values
    averaged.
    """
    if is_mixture_dataset(data) and source is None:
        raise ValueError("Must specify source when dataset is a mixture.")

    state_group = "State/" + source if source is not None else "State"
    features_var = f"{state_group}/locus_features"

    shap_array = data["SHAP_values"]
    locus_dim = "locus" if "locus" in shap_array.dims else "shap_locus"

    shap_values = (
        shap_array.to_dataframe()
        .reset_index()
        .rename(
            columns={
                "shap_component": "component",
                locus_dim: "locus",
                "shap_features": "feature",
                "SHAP_values": "shap_value",
            }
        )
    )

    feature_values = (
        data[features_var]
        .sel(locus=shap_values.locus.unique())
        .to_dataframe()
        .reset_index()
        .rename(columns={features_var: "feature_value"})
    )
    shap_values = shap_values.merge(
        feature_values,
        on=["locus", "feature"],
        how="inner",
    )

    # Explicit check instead of catching the ValueError pandas raises when the
    # split yields a single column: unrelated errors are no longer swallowed.
    if shap_values["feature"].str.contains(":", regex=False).any():
        shap_values[["feature", "window"]] = shap_values["feature"].str.split(
            ":", expand=True, n=1
        )
        shap_values = (
            shap_values.groupby(["component", "locus", "feature"])
            .agg(
                {
                    "feature_value": "mean",
                    "shap_value": "sum",
                }
            )
            .reset_index()
        )

    effect_size = (
        shap_values.groupby(["component", "feature"])["shap_value"]
        .apply(lambda x: np.quantile(np.abs(x), 0.97))
        .rename("effect_size")
    )

    def nan_corr(x, y):
        """Compute correlation, ignoring NaNs."""
        mask = ~np.isnan(x) & ~np.isnan(y)
        if np.sum(mask) < 2:
            return np.nan
        return np.corrcoef(x[mask], y[mask])[0, 1]

    correlation = (
        shap_values.groupby(["component", "feature"])[["shap_value", "feature_value"]]
        .apply(lambda x: nan_corr(x["shap_value"], x["feature_value"]))
        .rename("correlation")
    )

    component_summary = effect_size.to_frame().join(correlation).reset_index()
    return component_summary
def equal_size_quantiles(
    dataset: GTensorDataset, var_name: str, n_bins: int = 10, key: Optional[str] = None
) -> GTensorDataset:
    """
    Bin loci by a variable into bins of (approximately) equal total length.

    Loci are sorted by ``dataset[var_name]``; the cumulative fraction of
    region length is then cut into ``n_bins`` equal parts, so every bin
    covers roughly the same amount of genome rather than the same number
    of loci.

    Parameters
    ----------
    dataset : GTensorDataset
        Dataset with a per-locus variable ``var_name`` and region lengths.
    var_name : str
        Variable used for ranking the loci.
    n_bins : int, default=10
        Number of bins; labels run from 0 to ``n_bins - 1``.
    key : str, optional
        Name of the new variable. Defaults to
        ``'<last path component of var_name>_qbins_<n_bins>'``.

    Returns
    -------
    GTensorDataset
        The input dataset with the bin labels stored under ``key``.

    Raises
    ------
    ValueError
        If ``n_bins`` is smaller than one.
    """
    if n_bins < 1:
        raise ValueError("n_bins must be a positive integer.")

    locus_index = np.arange(dataset.sizes["locus"])
    table = pd.DataFrame(
        {
            "length": dataset.sections["Regions"].length.values,
            "value": dataset[var_name].values,
        },
        index=locus_index,
    )
    table = table.sort_values(by="value", ascending=True)

    cumulative_length = table["length"].cumsum()
    table["cumm_fraction"] = cumulative_length / cumulative_length.iloc[-1]

    # floor(fraction * n_bins) gives n_bins equal-width slices of [0, 1);
    # the final locus has fraction exactly 1.0 and is clipped into the last
    # bin instead of opening an extra one.
    table["bin"] = np.minimum(
        (table["cumm_fraction"] * n_bins).astype(int), n_bins - 1
    )

    if key is None:
        key = f'{var_name.rsplit("/", 1)[-1]}_qbins_{n_bins}'

    # .loc[locus_index] restores the original locus order.
    dataset[key] = xr.DataArray(
        table["bin"].loc[locus_index].values,
        dims="locus",
    )
    logger.info("Added key: " + key)
    return dataset
def slice_samples(dataset: GTensorDataset, samples: List[str]) -> GTensorDataset:
    """
    Restrict a dataset to a subset of samples.

    Parameters
    ----------
    dataset : GTensorDataset
        Dataset with a ``sample`` coordinate.
    samples : List[str]
        Samples to keep. When empty, no selection is applied to the
        underlying data.

    Returns
    -------
    GTensorDataset
        ``SampleSlice`` built from the (possibly unchanged) dataset and
        ``samples``. The wrapper is returned even when ``samples`` is empty.

    Notes
    -----
    The selection runs through the ``mutate`` helper so it is applied via
    ``dataset.mutate``.
    """

    def _select(ds):
        if len(samples) > 0:
            return ds.sel(sample=list(samples))
        return ds

    subset = mutate(_select)(dataset)
    return SampleSlice(subset, samples)
def slice_regions(
    dataset: GTensorDataset, *regions: str, lazy: bool = False
) -> GTensorDataset:
    """
    Keep only the loci that overlap at least one query region.

    Parameters
    ----------
    dataset : GTensorDataset
        Dataset with a ``Regions`` section (``chrom``, ``start``, ``end``).
    *regions : str
        Region strings; each one is handed to ``parse_region``, which must
        yield a ``(chrom, start, end)`` triple.
    lazy : bool, default=False
        Return a ``LocusSlice`` instead of a materialised selection. Forced
        to true when the dataset has no ``X`` data variable.

    Returns
    -------
    GTensorDataset
        ``LocusSlice`` in lazy mode, otherwise ``dataset.isel(locus=mask)``.

    Raises
    ------
    ValueError
        If no locus overlaps any of the query regions.

    Notes
    -----
    Overlap is evaluated with ``pandas.IntervalIndex.overlaps`` using the
    pandas default interval closure for both the loci and the query.
    """
    use_lazy = lazy or "X" not in dataset.data_vars

    queries = [parse_region(region) for region in regions]

    region_table = dataset.sections["Regions"]
    regions_mask = np.zeros(len(region_table.chrom), dtype=bool)

    for chrom, start, end in queries:
        on_chrom = region_table.chrom.values == chrom
        if not np.any(on_chrom):
            continue
        candidates = pd.IntervalIndex.from_arrays(
            region_table.start.values[on_chrom],
            region_table.end.values[on_chrom],
        )
        # Accumulate hits across all query regions on this chromosome.
        regions_mask[on_chrom] |= candidates.overlaps(pd.Interval(start, end))

    if not np.any(regions_mask):
        raise ValueError(f"No regions match the specified query: {regions}")

    logger.info(
        f"Found {np.sum(regions_mask)}/{len(regions_mask)} regions matching query."
    )

    if use_lazy:
        return LocusSlice(dataset, locus=regions_mask)
    return dataset.isel(locus=regions_mask)
def annot_empirical_marginal(
    dataset: GTensorDataset, key: str = "empirical_marginal"
) -> GTensorDataset:
    """
    Add empirical marginal mutation rates to a dataset.

    The ``X`` arrays of all samples are summed. The first sample is
    densified and used as the starting value; every further sample is added
    in COO form when it is sparse.

    Parameters
    ----------
    dataset : GTensorDataset
        Dataset exposing ``list_samples``, ``fetch_sample`` and
        ``iter_samples``.
    key : str, default="empirical_marginal"
        Base name of the variables that are added.

    Returns
    -------
    GTensorDataset
        The input dataset with two new variables:
        - ``{key}``: summed counts divided by the regions' context
          frequencies, missing values replaced by 0
        - ``{key}_locus``: counts summed over every dimension except
          ``locus``, divided by region length, stored as float32
    """

    def _as_dense(arr):
        return arr.asdense() if arr.is_sparse() else arr

    def _as_coo(arr):
        return arr.ascoo() if arr.is_sparse() else arr

    remaining = list(dataset.list_samples()[1:])
    progress = tqdm(
        dataset.iter_samples(subset=remaining),
        desc="Reducing samples",
        total=len(remaining),
    )

    running_total = _as_dense(dataset.fetch_sample(dataset.list_samples()[0]).X)
    for sample in progress:
        running_total = running_total + _as_coo(sample.X)
    X_emp = _as_dense(running_total)

    logger.info(f'Added key: "{key}"')
    per_context = X_emp / dataset.sections["Regions"].context_frequencies
    dataset[key] = per_context.fillna(0.0)

    locus_key = f"{key}_locus"
    logger.info(f'Added key: "{locus_key}"')
    per_locus_counts = X_emp.sum(dim=dims_except_for(X_emp.dims, "locus"))
    per_locus_rate = per_locus_counts / dataset.sections["Regions"].length
    dataset[locus_key] = per_locus_rate.fillna(0.0).astype(np.float32)

    return dataset
def make_mixture_dataset(**datasets: GTensorDataset) -> GTensorDataset:
    """
    Combine several datasets into one multi-source dataset.

    For every dataset the variables of the ``Features`` and ``State`` groups
    are renamed to ``<group>/<source>/<basename>``. All other variables are
    taken from the first dataset only. The result gets a ``source``
    coordinate listing the keyword names.

    Parameters
    ----------
    **datasets : GTensorDataset
        Datasets to combine; each keyword becomes the source identifier.
        A dataset lacking a ``Features`` or ``State`` group contributes no
        variables for that group.

    Returns
    -------
    GTensorDataset
        ``CorpusInterface`` wrapping the merged dataset.

    Raises
    ------
    IndexError
        If no dataset is given.
    """
    per_source_levels = ("Features", "State")
    source_names = list(datasets)

    merge_dsets = []
    for source_name, dataset in datasets.items():
        groups = dataset.sections.groups
        rename_map = {
            old_name: f"{level}/{source_name}/{os.path.basename(old_name)}"
            for level in per_source_levels
            # A group may be absent, e.g. "State" for datasets loaded
            # without state information.
            for old_name in groups.get(level, [])
        }
        # Index with a real list; a dict-keys view is not a portable key.
        merge_dsets.append(dataset[list(rename_map)].rename(rename_map))

    first_dataset = list(datasets.values())[0]
    transfer_vars = [
        var_name
        for level, var_names in first_dataset.sections.groups.items()
        if level not in per_source_levels
        for var_name in var_names
    ]
    merge_dsets.append(first_dataset[transfer_vars])

    merged = xr.merge(merge_dsets)
    merged["source"] = xr.DataArray(
        source_names,
        dims=("source",),
    )
    merged = merged.set_coords("source")

    return CorpusInterface(merged)
def dims_except_for(dims: Iterable, *keepdims: str) -> tuple:
    """
    Return the entries of ``dims`` that are not listed in ``keepdims``.

    Parameters
    ----------
    dims : Iterable
        Dimension names to filter.
    *keepdims : str
        Names to leave out of the result.

    Returns
    -------
    tuple
        Remaining names, without duplicates, in their order of first
        appearance in ``dims``. The order is deterministic, so callers that
        build new dimensions from it behave the same on every run.
    """
    excluded = set(keepdims)
    # dict.fromkeys de-duplicates while keeping first-seen order.
    return tuple(d for d in dict.fromkeys(dims) if d not in excluded)
def match_dims(X: xr.DataArray, **dim_sizes: int) -> xr.DataArray:
    """
    Expand ``X`` with every requested dimension it does not have yet.

    Parameters
    ----------
    X : xr.DataArray
        Array to expand.
    **dim_sizes : int
        Dimension name -> size. Names already present in ``X.dims`` are
        ignored.

    Returns
    -------
    xr.DataArray
        Result of ``X.expand_dims`` with the missing dimensions, passed in
        the order the keyword arguments were given (deterministic).
    """
    existing = set(X.dims)
    missing = {
        dim: size for dim, size in dim_sizes.items() if dim not in existing
    }
    return X.expand_dims(missing)
def get_regions_filename(dataset: GTensorDataset) -> str:
    """
    Return the path of the regions file that belongs to a dataset.

    The ``regions_file`` attribute is resolved relative to the directory of
    the dataset's ``filename`` attribute.

    Parameters
    ----------
    dataset : GTensorDataset
        Dataset whose ``attrs`` contain ``filename`` and ``regions_file``.

    Returns
    -------
    str
        Directory of ``filename`` joined with ``regions_file``.
    """
    attrs = dataset.attrs
    dataset_dir = os.path.dirname(attrs["filename"])
    return os.path.join(dataset_dir, attrs["regions_file"])
def unstack_regions(dataset: GTensorDataset) -> GTensorDataset:
    """
    Replace the ``Regions`` variables with unstacked coordinates.

    The locus values, the dataset's regions file and the number of loci are
    passed to the BED12 ``unstack_regions`` helper, which returns
    chromosome, start and end arrays plus a locus index. The existing
    ``Regions`` variables are dropped, the dataset is re-indexed along
    ``locus`` with that index, and the new coordinate arrays are stored as
    ``Regions/chrom``, ``Regions/start`` and ``Regions/end``.

    Parameters
    ----------
    dataset : GTensorDataset
        Dataset with ``filename`` and ``regions_file`` attributes.

    Returns
    -------
    GTensorDataset
        Re-indexed dataset carrying the unstacked region coordinates.
    """
    locus = dataset.coords["locus"]
    n_regions = locus.size

    chrom, start, end, idx = _unstack_regions(
        dataset.coords["locus"].values,
        get_regions_filename(dataset),
        n_regions,
    )

    without_regions = dataset.drop_vars(dataset.sections.groups["Regions"])
    reindexed = without_regions.isel(locus=idx)

    new_region_vars = {
        "Regions/chrom": xr.DataArray(chrom, dims=("locus",)),
        "Regions/start": xr.DataArray(start, dims=("locus",)),
        "Regions/end": xr.DataArray(end, dims=("locus",)),
    }
    return reindexed.update(new_region_vars)
def fetch_features(
    dataset: GTensorDataset,
    *feature_names: str,
    source: Union[str, None] = None,
) -> xr.DataArray:
    """
    Collect features from the dataset's "Features" section into one array.

    Parameters
    ----------
    dataset : GTensorDataset
        Dataset with a "Features" section.
    *feature_names : str
        Glob patterns matched against the full feature name and against its
        basename. When empty, every feature is a candidate.
    source : str, optional
        Glob pattern matched against the directory part of the feature
        name. When None, no source filter is applied.

    Returns
    -------
    xarray.DataArray
        Array named "Features" built with dims ("feature", "locus") and
        coords "locus", "feature" (full names), "feature_name" (basenames)
        and "source" (directory parts). Length-one dimensions are removed
        by a final ``squeeze()``.

    Raises
    ------
    ValueError
        If no feature matches, or if the selected features are neither all
        numeric nor all strings.

    Notes
    -----
    Features whose basename equals one of ``feature_names`` come first, in
    the order given; all other matches follow in their original order.
    """
    from fnmatch import fnmatch

    def _is_selected(name):
        base = os.path.basename(name)
        if feature_names and not any(
            fnmatch(name, pattern) or fnmatch(base, pattern)
            for pattern in feature_names
        ):
            return False
        return source is None or fnmatch(os.path.dirname(name), source)

    fnames = [
        name for name, _ in dataset.sections["Features"].items() if _is_selected(name)
    ]
    if not fnames:
        raise ValueError("No matching features found.")

    # All features must share one dtype family (all numeric or all string) so
    # stacking them does not trigger unexpected type conversions.
    dtypes = {dataset.sections["Features"][name].dtype for name in fnames}

    def _all_subtypes_of(kind):
        return all(np.issubdtype(dtype, kind) for dtype in dtypes)

    if not (_all_subtypes_of(np.number) or _all_subtypes_of(np.str_)):
        raise ValueError("All features must have a numeric data type.")

    if len(feature_names) > 0:

        def _requested_rank(name):
            base = os.path.basename(name)
            if base in feature_names:
                return feature_names.index(base)
            return len(feature_names)

        # Stable sort: ties keep their original relative order.
        fnames.sort(key=_requested_rank)

    features = [os.path.basename(name) for name in fnames]
    sources = [os.path.dirname(name) for name in fnames]
    stacked = np.vstack([dataset.sections["Features"][name].values for name in fnames])

    feature_matrix = xr.DataArray(
        stacked,
        dims=("feature", "locus"),
        coords={
            "locus": dataset.coords["locus"].values,
            "feature": fnames,
            "feature_name": ("feature", features),
            "source": ("feature", sources),
        },
        name="Features",
    )
    return feature_matrix.squeeze()
class ComponentWrapper:
    """
    Accessor for per-component variables of a dataset.

    Wraps a dataset that carries a "component" coordinate and resolves
    components given either by name or by integer position.
    """

    def __init__(self, dataset):
        # Refuse datasets without a "component" coordinate up front.
        if "component" not in dataset.coords:
            raise ValueError("Dataset does not contain 'component' coordinate.")
        self.dataset = dataset

    def _get_k(self, component_name):
        """Return the positional index for a component name; ints pass through."""
        if isinstance(component_name, int):
            return component_name
        try:
            return self.component_names.index(component_name)
        except ValueError:
            raise ValueError(f"Component {component_name} not found in model.")

    @property
    def n_components(self):
        """Number of entries in the "component" coordinate."""
        component_values = self.dataset.coords["component"].values
        return len(component_values)

    @property
    def component_names(self):
        """Values of the "component" coordinate as a list."""
        component_values = self.dataset.coords["component"].values
        return list(component_values)

    def get_spectrum(self, idx: Union[str, int]) -> xr.DataArray:
        """Select "Spectra/spectra" at the given component."""
        position = self._get_k(idx)
        spectra = self.dataset["Spectra/spectra"]
        return spectra.isel(component=position)

    def get_interactions(self, idx: Union[str, int]) -> xr.DataArray:
        """Select "Spectra/interactions" at the given component."""
        position = self._get_k(idx)
        interactions = self.dataset["Spectra/interactions"]
        return interactions.isel(component=position)

    def get_shared_effects(self, idx: Union[str, int]) -> xr.DataArray:
        """Select "Spectra/shared_effects" at the given component."""
        shared_effects = self.dataset["Spectra/shared_effects"]
        return shared_effects.isel(component=self._get_k(idx))
def rename_components(dataset: GTensorDataset, names: List[str]) -> GTensorDataset:
    """
    Assign new names to the dataset's "component" coordinate.

    When the dataset also has a "shap_component" coordinate, its entries are
    translated through the same old-name -> new-name mapping.

    Parameters
    ----------
    dataset : GTensorDataset
        Dataset carrying a "component" coordinate.
    names : typing.List[str]
        Replacement names, one per existing component, in coordinate order.

    Returns
    -------
    GTensorDataset
        The result of ``dataset.mutate`` with the renamed coordinates assigned.

    Raises
    ------
    ValueError
        If the dataset has no "component" coordinate, or the number of names
        differs from the number of components.
    KeyError
        If an entry of "shap_component" is not one of the current component
        names.
    """
    wrapper = ComponentWrapper(dataset)

    if len(names) != wrapper.n_components:
        raise ValueError("The number of names must match the number of components")

    old_to_new = dict(zip(wrapper.component_names, names))
    updated_coords = {"component": names}

    if "shap_component" in dataset.coords:
        shap_names = dataset.coords["shap_component"].data
        try:
            updated_coords["shap_component"] = [old_to_new[old] for old in shap_names]
        except KeyError:
            raise KeyError(
                "Some components in dataset do not match the model components. Just delete the SHAP_values and try again."
            )

    return dataset.mutate(lambda ds: ds.assign_coords(updated_coords))
def _fetch_component_data(
    dataset: GTensorDataset, component_name: Union[str, int], fetch_fn
) -> xr.DataArray:
    """
    Call a ``ComponentWrapper`` getter by name and tag the result.

    ``fetch_fn`` is the name of the ``ComponentWrapper`` method to invoke
    (e.g. "get_spectrum"). The returned array has ``attrs["dtype"]`` set to
    the dataset's own ``attrs["dtype"]``.
    """
    getter = getattr(ComponentWrapper(dataset), fetch_fn)
    selected = getter(component_name)
    # Carry the dataset-level dtype attribute over to the selected array.
    selected.attrs["dtype"] = dataset.attrs["dtype"]
    return selected
def list_components(dataset: GTensorDataset) -> List[str]:
    """
    Return the names stored in the dataset's "component" coordinate.

    Parameters
    ----------
    dataset : GTensorDataset
        Dataset carrying a "component" coordinate.

    Returns
    -------
    List[str]
        Component names, in coordinate order.

    Raises
    ------
    ValueError
        If the dataset does not contain a 'component' coordinate.
    """
    return ComponentWrapper(dataset).component_names
def fetch_component(dataset: GTensorDataset, component_name: Union[str, int]) -> xr.DataArray:
    """
    Return the "Spectra/spectra" entry for a single component.

    Parameters
    ----------
    dataset : GTensorDataset
        Dataset containing component spectra.
    component_name : Union[str, int]
        Component name, or its integer position along "component".

    Returns
    -------
    xr.DataArray
        The selected spectrum, with ``attrs["dtype"]`` copied from the
        dataset's attributes.

    Raises
    ------
    ValueError
        If the dataset has no "component" coordinate or the named component
        is not found.
    """
    spectrum = _fetch_component_data(
        dataset,
        component_name,
        "get_spectrum",
    )
    return spectrum
def fetch_interactions(dataset: GTensorDataset, component_name: Union[str, int]) -> xr.DataArray:
    """
    Return the "Spectra/interactions" entry for a single component.

    Parameters
    ----------
    dataset : GTensorDataset
        Dataset containing component interaction data.
    component_name : Union[str, int]
        Component name, or its integer position along "component".

    Returns
    -------
    xr.DataArray
        The selected interaction effects, with ``attrs["dtype"]`` copied
        from the dataset's attributes.

    Raises
    ------
    ValueError
        If the dataset has no "component" coordinate or the named component
        is not found.
    """
    interactions = _fetch_component_data(
        dataset,
        component_name,
        "get_interactions",
    )
    return interactions
def fetch_shared_effects(dataset: GTensorDataset, component_name: Union[str, int]) -> xr.DataArray:
    """
    Return the "Spectra/shared_effects" entry for a single component.

    Parameters
    ----------
    dataset : GTensorDataset
        Dataset containing component shared-effects data.
    component_name : Union[str, int]
        Component name, or its integer position along "component".

    Returns
    -------
    xr.DataArray
        The selected shared effects, with ``attrs["dtype"]`` copied from
        the dataset's attributes.

    Raises
    ------
    ValueError
        If the dataset has no "component" coordinate or the named component
        is not found.
    """
    shared_effects = _fetch_component_data(
        dataset,
        component_name,
        "get_shared_effects",
    )
    return shared_effects
def excel_report(self, dataset: GTensorDataset, output: str, normalization="global"):
    """
    Write model results to a multi-sheet Excel workbook.

    Parameters
    ----------
    dataset : GTensorDataset
        Dataset containing the model results to export.
    output : str
        Output file path for the Excel report.
    normalization : str, default "global"
        Forwarded unchanged to ``self.format_component`` for every signature.

    Raises
    ------
    ImportError
        If the Excel writer backend cannot be imported when the workbook is
        opened (typically because openpyxl is not installed).

    Notes
    -----
    The Excel file will contain the following sheets:

    - Signature_{name}: each formatted component, rescaled to sum to 1000
    - Sample_contributions: written when ``dataset`` has ``contributions``
    - SHAP_transformed_features: written when ``dataset`` has ``SHAP_values``
    - SHAP_original_features: written when ``dataset`` has ``SHAP_values``
    - SHAP_values_{component}: one sheet per entry of "shap_component"

    Requires openpyxl to be installed: pip install openpyxl
    """

    def _per_thousand(x):
        # Rescale so that the entries sum to 1000.
        return x / x.sum() * 1000

    # pandas imports the Excel engine only when the writer is constructed,
    # so a missing backend surfaces here rather than on `import pandas`.
    try:
        writer = pd.ExcelWriter(output)
    except ImportError as err:
        raise ImportError(
            "openpyxl is required to save excel reports, install with `pip install openpyxl`"
        ) from err

    with writer:

        for sig in self.component_names:
            (
                _per_thousand(self.format_component(sig, normalization=normalization))
                .to_pandas()
                .T.to_excel(
                    writer,
                    sheet_name=f"Signature_{sig}",
                )
            )

        if hasattr(dataset, "contributions"):
            (
                dataset.contributions.stack(observations=("source", "component"))
                .transpose("sample", ...)
                .to_pandas()
                .to_excel(
                    writer,
                    sheet_name="Sample_contributions",
                )
            )

        if hasattr(dataset, "SHAP_values"):

            shap_components = dataset.SHAP_values.coords["shap_component"].values

            # The feature sheets are taken from the first component's explanation.
            expl = get_explanation(dataset, shap_components[0])

            pd.DataFrame(
                expl.data,
                columns=expl.feature_names,
            ).to_excel(
                writer,
                sheet_name="SHAP_transformed_features",
                index=False,
            )

            display_data = expl.display_data.copy()
            display_data.columns = expl.feature_names
            display_data.to_excel(
                writer,
                sheet_name="SHAP_original_features",
                index=False,
            )

            for component in shap_components:
                expl = get_explanation(dataset, component)

                pd.DataFrame(
                    expl.values,
                    columns=expl.feature_names,
                ).to_excel(
                    writer,
                    sheet_name=f"SHAP_values_{component}",
                    index=False,
                )