# Source code for mutopia.plot.track_plot.tracks

"""
Prebuilt track configurations for genome view plotting.

This module provides ready-to-use tracks built on top of the primitives in
``mutopia.plot.track_plot`` and the helpers in ``.transforms``.
"""

from __future__ import annotations
import numpy as np
from functools import partial
import mutopia.plot.track_plot as tr
from mutopia.palettes import categorical_palette, diverging_palette
from .transforms import TopographyTransformer, minmax_scale
from typing import Any, Callable, Optional, Sequence, Mapping, TYPE_CHECKING

if TYPE_CHECKING:
    from matplotlib.axes import Axes
    from mutopia.gtensor.gtensor import GTensorDataset
    from .track_plot import GenomeView

# Names exported by ``from mutopia.plot.track_plot.tracks import *``.
# Only the track builders and ``order_components`` are listed; the private
# ``_topography_ax_fn`` helper is not.
__all__ = [
    "plot_gene_annotation",
    "plot_marginal_observed_vs_expected",
    "plot_component_rates",
    "plot_topography",
    "plot_empirical_topography",
    "plot_gene_expression_track",
    "order_components",
    "plot_gene_expression_strip",
]

def plot_gene_expression_strip(feature_name: str = "GeneExpression", label=None) -> Callable[..., "Axes"]:
    """
    Build a greyscale heatmap strip for a single dataset feature.

    The feature is fetched with ``tr.feature_matrix``, given an extra
    ``"feature"`` dimension via ``expand_dims``, and passed through
    ``np.log1p`` before being drawn with the ``"Greys"`` palette.

    Parameters
    ----------
    feature_name : str, default "GeneExpression"
        Feature key handed to ``tr.feature_matrix``.
    label : str, optional
        Track label. When falsy, ``feature_name`` is used instead.

    Returns
    -------
    callable
        The track callable produced by ``tr.heatmap_plot``.
    """

    def _strip_frame(ax):
        # Hide every spine, then blank the y tick labels. The individual call
        # results are handed back as one tuple.
        spine_results = tuple(
            ax.spines[side].set_visible(False)
            for side in ("top", "right", "left", "bottom")
        )
        return (*spine_results, ax.set_yticklabels([]))

    log_values = tr.pipeline(
        tr.feature_matrix(feature_name),
        lambda values: values.expand_dims("feature"),
        lambda values: np.log1p(values),
    )
    return tr.heatmap_plot(
        log_values,
        palette="Greys",
        ax_fn=_strip_frame,
        label=label or feature_name,
    )
def plot_gene_annotation(
    gtf: str,
    label: str = "Genes",
    all_labels_inside: bool = False,
    style: str = "flybase",
    label_genes: bool = True,
    fontsize: int = 5,
    ax_fn: Callable[["Axes"], Any] = lambda ax: ax.spines["bottom"].set_visible(False),
    **kw: Any,
) -> Callable[..., "Axes"]:
    """
    Create a gene annotation track from a GTF file.

    Parameters
    ----------
    gtf : str
        Path to the GTF file.
    label : str, default "Genes"
        Track label.
    all_labels_inside : bool, default False
        Forwarded to the track as ``all_labels_inside``.
    style : str, default "flybase"
        Gene track style.
    label_genes : bool, default True
        Forwarded to the track as its ``labels`` option.
    fontsize : int, default 5
        Font size for labels.
    ax_fn : callable
        Function used to customize the axes; the default hides the bottom
        spine.
    **kw
        Additional keyword arguments forwarded to ``tr.static_track``.

    Returns
    -------
    callable
        Track plotting function returned by ``tr.static_track``.

    Examples
    --------
    >>> import mutopia.plot.track_plot as tr
    >>> cfg = lambda v: tr.tracks.plot_gene_annotation("/path/genes.gtf")
    >>> _ = tr.plot_view(cfg, tr.make_view(ds, region="chr1:1-2_000_000"))
    """
    # Options this function names explicitly; anything else rides along in kw.
    track_options = dict(
        label=label,
        labels=label_genes,
        all_labels_inside=all_labels_inside,
        style=style,
        fontsize=fontsize,
        ax_fn=ax_fn,
    )
    return tr.static_track("GtfTrack", gtf, **track_options, **kw)
def plot_marginal_observed_vs_expected(
    view: "GenomeView",
    smooth: int = 20,
    pred_smooth: int = 10,
    label: str = "Mutation rate",
    legend: bool = True,
    height: float = 1,
    empirical_kw: Mapping[str, Any] = {"alpha": 0.5, "s": 0.1, "color": "lightgrey"},
    predicted_kw: Mapping[str, Any] = {
        "color": categorical_palette[1],
        "dashes": (1, 1),
        "alpha": 0.8,
        "linewidth": 0.75,
    },
    ax_fn: Callable[["Axes"], Any] = lambda ax: None,
) -> Callable[..., "Axes"]:
    """
    Stack an empirical scatter layer and a predicted line layer in one track.

    Both layers select their variable, smooth it with ``view.smooth`` and
    pass it through ``tr.renorm`` before plotting.

    Parameters
    ----------
    view : GenomeView
        Genome view whose ``smooth`` method builds the smoothing steps.
    smooth : int, default 20
        Window passed to ``view.smooth`` for ``"empirical_marginal_locus"``.
    pred_smooth : int, default 10
        Window passed to ``view.smooth`` for ``"predicted_marginal_locus"``.
    label : str, default "Mutation rate"
        Track label.
    legend : bool, default True
        Whether to show a legend.
    height : float, default 1
        Track height.
    empirical_kw : dict
        Keyword args unpacked into ``tr.scatterplot``.
    predicted_kw : dict
        Keyword args unpacked into ``tr.line_plot``.
    ax_fn : callable
        Optional axes customization function; the default does nothing.

    Returns
    -------
    callable
        The stacked track returned by ``tr.stack_plots``.

    Examples
    --------
    >>> import mutopia.plot.track_plot as tr
    >>> view = tr.make_view(ds, region="chr1:1_000_000-1_200_000")
    >>> cfg = lambda v: tr.tracks.plot_marginal_observed_vs_expected(view)
    >>> _ = tr.plot_view(cfg, view)
    """
    # Empirical layer: scatter of the smoothed, renormalized observed rate.
    empirical_rate = tr.pipeline(
        tr.select("empirical_marginal_locus"),
        view.smooth(smooth),
        tr.renorm,
    )
    empirical_layer = tr.scatterplot(empirical_rate, **empirical_kw)

    # Predicted layer: line of the smoothed, renormalized model rate.
    predicted_rate = tr.pipeline(
        tr.select("predicted_marginal_locus"),
        view.smooth(pred_smooth),
        tr.renorm,
    )
    predicted_layer = tr.line_plot(predicted_rate, **predicted_kw)

    return tr.stack_plots(
        empirical_layer,
        predicted_layer,
        label=label,
        legend=legend,
        height=height,
        ax_fn=ax_fn,
    )
def plot_component_rates(
    view: "GenomeView",
    *components: Any,
    smooth: int = 30,
    height: float = 0.5,
    label: Optional[str] = None,
    color: str = categorical_palette[0],
    linewidth: float = 0.5,
) -> list[Callable[..., "Axes"]]:
    """
    Build one filled line track per requested component.

    Parameters
    ----------
    view : GenomeView
        Genome view whose ``smooth`` method builds the smoothing step.
    *components : Any
        Component identifiers; each is used to select from
        ``"component_distributions_locus"``.
    smooth : int, default 30
        Window passed to ``view.smooth``.
    height : float, default 0.5
        Track height per component.
    label : str, optional
        Label applied to every track. When falsy, each track is labelled
        with ``str(component)``.
    color : str, default categorical_palette[0]
        Fill/line color.
    linewidth : float, default 0.5
        Line width.

    Returns
    -------
    list of callable
        One ``tr.line_plot`` callable per component, in the order given.
    """

    def _rate_track(component):
        # Select this component's rate, smooth it, then renormalize.
        rate = tr.pipeline(
            tr.select("component_distributions_locus", component=component),
            view.smooth(smooth),
            tr.renorm,
        )
        return tr.line_plot(
            rate,
            fill=True,
            label=label or str(component),
            linewidth=linewidth,
            color=color,
            height=height,
        )

    return [_rate_track(component) for component in components]
def _topography_ax_fn(ax: "Axes", transformer: TopographyTransformer):
    """
    Style the y axis of a topography heatmap.

    Draws white horizontal grid lines, places major ticks every 16 rows
    (unlabelled, zero length) and minor ticks every 16 rows starting at row
    8, and labels the minor ticks with ``transformer.labels`` in reverse
    order.

    Parameters
    ----------
    ax : Axes
        Axes to modify in place.
    transformer : TopographyTransformer
        Supplies ``ordering_`` (its length bounds the tick positions) and
        ``labels``.

    Returns
    -------
    tuple
        The return values of the seven axes calls, in call order.
    """
    grid_result = ax.grid(True, axis="y", linestyle="-", linewidth=0.4, color="white")

    n_rows = len(transformer.ordering_)
    major_result = ax.set_yticks(np.arange(0, n_rows, 16), minor=False)
    minor_result = ax.set_yticks(np.arange(8, n_rows, 16), minor=True)

    # Major ticks only carry the grid lines; the text sits on the minor ticks.
    blank_result = ax.set_yticklabels([])
    text_result = ax.set_yticklabels(transformer.labels[::-1], minor=True, ha="right")

    minor_params = ax.tick_params(axis="y", which="minor", labelsize=7)
    major_params = ax.tick_params(axis="y", which="major", length=0)

    return (
        grid_result,
        major_result,
        minor_result,
        blank_result,
        text_result,
        minor_params,
        major_params,
    )
def plot_topography(
    transformer: TopographyTransformer,
    palette: str = "Greys",
    yticks: bool = False,
    label: str = "Predicted\ntopography",
    vmin: float = -3,
    vmax: float = 3,
    height: float = 1.5,
    **heatmap_kw: Any,
) -> Callable[..., "Axes"]:
    """
    Heatmap track of the matrix produced by ``transformer.transform``.

    The y axis is styled by ``_topography_ax_fn`` using the same transformer.

    Parameters
    ----------
    transformer : TopographyTransformer
        Supplies ``transform`` as the data accessor, plus the ordering and
        labels used for the y axis.
    palette : str, default "Greys"
        Colormap name.
    yticks : bool, default False
        Forwarded to ``tr.heatmap_plot`` as ``yticks``.
    label : str, default "Predicted\\ntopography"
        Track label.
    vmin, vmax : float
        Color scale limits.
    height : float, default 1.5
        Track height.
    **heatmap_kw
        Extra kwargs forwarded to ``tr.heatmap_plot``.

    Returns
    -------
    callable
        Heatmap plotting callable.
    """

    def _style_axes(ax):
        return _topography_ax_fn(ax, transformer)

    heatmap_options = dict(
        palette=palette,
        yticks=yticks,
        label=label,
        vmin=vmin,
        vmax=vmax,
        height=height,
        ax_fn=_style_axes,
    )
    # Double-splat (rather than dict.update) so a duplicate key in heatmap_kw
    # still raises TypeError instead of silently overriding.
    return tr.heatmap_plot(transformer.transform, **heatmap_options, **heatmap_kw)
def plot_empirical_topography(
    transformer: TopographyTransformer,
    palette: str = "Greys",
    label: str = "Empirical\ntopography",
    height: float = 1.5,
    quantile_cutoff: float = 0.999,
    s: float = 0.01,
    alpha: float = 0.5,
    topography_kw: Mapping[str, Any] = {"vmin": -3, "vmax": 3, "cbar": False, "alpha": 0.15},
    **scatter_kw: Any,
) -> Callable[..., "Axes"]:
    """
    Overlay empirical topography points on a predicted topography heatmap.

    Parameters
    ----------
    transformer : TopographyTransformer
        Supplies ``_fetch_matrix``, ``ordering_`` and the background heatmap
        (via ``plot_topography``).
    palette : str, default "Greys"
        Colormap for the background heatmap.
    label : str, default "Empirical\\ntopography"
        Track label.
    height : float, default 1.5
        Track height.
    quantile_cutoff : float, default 0.999
        Upper quantile (of the finite values) at which empirical intensities
        are clipped.
    s : float, default 0.01
        Scatter point size.
    alpha : float, default 0.5
        Scatter alpha. Previously accepted but ignored; it is now forwarded
        to ``ax.scatter``.
    topography_kw : dict
        Extra kwargs passed to the background heatmap.
    **scatter_kw
        Extra kwargs passed to ``ax.scatter``. ``cmap`` defaults to
        ``"Greys"`` and may be overridden here.

    Returns
    -------
    callable
        Stacked scatter + heatmap track.

    Notes
    -----
    Each point's x position is drawn uniformly at random inside its region
    with ``np.random.rand``, so the picture differs between draws unless the
    global NumPy seed is fixed.
    """
    from scipy.sparse import coo_matrix, csc_matrix

    def _get_heatmap(dataset: "GTensorDataset") -> coo_matrix:
        matrix = transformer._fetch_matrix("empirical_marginal", dataset)
        x = matrix[transformer.ordering_].values.T

        # Clip at the requested quantile of the finite entries. With no finite
        # entries there is nothing to plot, so fall back to a zero ceiling
        # rather than asking np.quantile about an empty array.
        finite = x[np.isfinite(x)]
        max_cut = np.quantile(finite, quantile_cutoff) if finite.size else 0.0

        x = np.nan_to_num(x, nan=0.0, neginf=0.0)
        x = np.clip(x, a_min=0.0, a_max=max_cut)

        # Sparse form: only non-zero cells become scatter points.
        x = coo_matrix(x)
        x.eliminate_zeros()
        return x

    def _topography_scatter(
        ax: "Axes",
        *,
        dataset: "GTensorDataset",
        start: np.ndarray,
        end: np.ndarray,
        idx: np.ndarray,
        interval: tuple[int, int],
        **kw: Any,
    ) -> "Axes":
        matrix = _get_heatmap(dataset)
        # Column-slice in CSC form, then return to COO for row/col/data access.
        matrix = coo_matrix(csc_matrix(matrix)[:, idx])

        # Jitter every point uniformly between its region's start and end.
        u = np.random.rand(len(matrix.data))
        x = ((end - start)[matrix.col]) * u + start[matrix.col]

        # Caller-supplied scatter kwargs win over the default colormap.
        scatter_options = {"cmap": "Greys", **scatter_kw}
        ax.scatter(
            x,
            matrix.row,
            c=matrix.data,
            s=s,
            alpha=alpha,
            **scatter_options,
        )
        ax.set_xlim(interval)
        return ax

    return tr.stack_plots(
        _topography_scatter,
        plot_topography(
            transformer,
            palette=palette,
            zorder=0,
            **topography_kw,
        ),
        label=label,
        height=height,
    )
def plot_gene_expression_track(
    expression_key: str = "GeneExpression",
    strand_key: str = "GeneStrand",
    linewidth: float = 0.5,
    label: str = "Gene\nexpression",
    color: str = "lightgrey",
    height: float = 1,
    log1p: bool = True,
) -> Callable[..., "Axes"]:
    """
    Bar track of signed gene expression.

    The two features are fetched together and multiplied along axis 0, so
    the sign of the strand feature is carried onto the expression value.
    With ``log1p`` enabled the product is mapped through
    ``sign(x) * log1p(|x|)``.

    Parameters
    ----------
    expression_key : str, default "GeneExpression"
        Dataset feature key for the expression values.
    strand_key : str, default "GeneStrand"
        Dataset feature key for the strand values.
    linewidth : float, default 0.5
        Width of the horizontal line drawn at y = 0.
    label : str, default "Gene\\nexpression"
        Track label.
    color : str, default "lightgrey"
        Bar color.
    height : float, default 1
        Track height.
    log1p : bool, default True
        If True, apply the symmetric log1p transform; otherwise plot the
        raw product.

    Returns
    -------
    callable
        Bar plot callable.
    """
    # Redundant with the module-level import; kept so the function body
    # resolves ``np`` exactly as before.
    import numpy as np

    def _strand_product(values):
        return np.prod(values, axis=0)

    def _symlog1p(values):
        return np.sign(values) * np.log1p(np.abs(values))

    def _unchanged(values):
        return values

    def _zero_line(ax):
        baseline = ax.axhline(
            0,
            color="k",
            linewidth=linewidth,
        )
        return (baseline, ax.spines["bottom"].set_visible(False))

    rescale = _symlog1p if log1p else _unchanged
    signed_expression = tr.pipeline(
        tr.feature_matrix(expression_key, strand_key),
        _strand_product,
        rescale,
    )
    return tr.bar_plot(
        signed_expression,
        ax_fn=_zero_line,
        label=label,
        color=color,
        height=height,
    )
def order_components(dataset: "GTensorDataset") -> np.ndarray:
    """
    Return component identifiers in the row order chosen by ``tr.reorder_df``.

    The ``"component_distributions_locus"`` variable is squeezed,
    renormalized row by row, converted to pandas and handed to
    ``tr.reorder_df``; the index of the result is the ordering. The original
    documentation describes this ordering as a hierarchical clustering.

    Parameters
    ----------
    dataset : GTensorDataset
        Input dataset providing ``component_distributions_locus``.

    Returns
    -------
    ndarray
        Ordered component identifiers.
    """
    reorder_rows = tr.pipeline(
        tr.select("component_distributions_locus"),
        lambda rates: rates.squeeze(),
        tr.apply_rows(tr.renorm),
        lambda rates: rates.to_pandas(),
        tr.reorder_df,
    )
    reordered = reorder_rows(dataset)
    return reordered.index.values
# NOTE(review): everything under this guard never runs. The indentation of
# this copy of the file was lost, so the nesting below (all three functions
# inside the ``if False:``) is an assumption -- confirm against the repository.
if False:

    def component_rate_summary(
        view: "GenomeView",
        *,
        ideogram: str,
        scalebar_size: int = int(1e7),
        scalebar_scale: str = "mb",
        pred_smooth: int = 20,
        empirical_smooth: int = 10,
        legend: bool = True,
        pred_kw: Mapping[str, Any] = {
            "color": categorical_palette[1],
            "dashes": (1, 1),
            "alpha": 1,
            "linewidth": 0.75,
        },
        component_smooth: int = 30,
        component_order: Optional[Sequence[Any]] = None,
    ) -> tuple[Any, ...]:
        """
        Summary: scale bar, ideogram, observed vs predicted, and per-component rates.

        Parameters
        ----------
        view : GenomeView
            Genome view passed to the rate tracks; ``view.dataset`` is used
            to compute the component order when none is given.
        ideogram : str
            Path passed to ``tr.ideogram``.
        scalebar_size : int, default 1e7
            Size passed to ``tr.scale_bar``.
        scalebar_scale : str, default "mb"
            ``scale`` passed to ``tr.scale_bar``.
        pred_smooth : int, default 20
            Smoothing for the predicted rate line.
        empirical_smooth : int, default 10
            Smoothing for the empirical rate points.
        legend : bool, default True
            Whether to show a legend in the rate plot.
        pred_kw : dict
            Predicted line kwargs.
        component_smooth : int, default 30
            Smoothing window for the component rate tracks.
        component_order : sequence, optional
            Explicit component order; computed with ``order_components``
            when None.

        Returns
        -------
        tuple
            Scale bar, ideogram, observed-vs-expected track, a spacer, then
            one track per component.
        """
        component_order = (
            order_components(view.dataset)
            if component_order is None
            else component_order
        )
        return (
            tr.scale_bar(scalebar_size, scale=scalebar_scale),
            tr.ideogram(ideogram, height=0.1),
            tr.tracks.plot_marginal_observed_vs_expected(
                view,
                smooth=empirical_smooth,
                pred_smooth=pred_smooth,
                predicted_kw=pred_kw,
                legend=legend,
            ),
            tr.spacer(0.1),
            *plot_component_rates(view, *component_order, smooth=component_smooth),
        )

    def mega_summary(
        view,
        *,
        dataset,
        mutation_rate_plot_kw={},
        shap_plot_kw={},
    ):
        """
        Build a four-column layout of feature tracks and per-component panels.

        The returned tuple holds ``tr.columns`` rows for a scale bar, an
        ideogram, the mutation-rate track, a functional-feature heatmap and
        a Repliseq heatmap, then a banner row, then a list with one row per
        component (spectrum, rate track, shared effects, SHAP summary).

        NOTE(review): the two ``{}`` defaults are shared between calls; they
        are only read here (unpacked / passed to ``dict.update``).
        """
        import seaborn as sns
        from mutopia.plot import plot_spectrum, plot_shap_summary
        from mutopia.gtensor import fetch_component

        feature_order = [
            "GeneExpression",
            "H3K36me3",
            "POLR2A",
            "H3K27me3",
            "H3K9me3",
            "H3K4me1",
            "H3K4me3",
            "H3K27ac",
            "ATACAccessible",
            "DNase",
            "HICEigenvector",
        ]

        def _plot_shared_effects(component_name, ax):
            # One-row heatmap of this component's shared effects; the first
            # entry is dropped by the ``.iloc[1:,]`` slice.
            shared_effect = (
                dataset.sections["Spectra"]["shared_effects"]
                .sel(component=component_name)
                .to_pandas()
                .iloc[1:,]
                .to_frame()
                .T
            )
            sns.heatmap(
                shared_effect,
                cmap=diverging_palette,
                linewidths=0.5,
                vmin=-0.5,
                vmax=0.5,
                cbar=False,
                ax=ax,
                square=True,
            )
            ax.set(
                ylabel="",
                xlabel="",
                xticklabels=[],
            )
            ax.set_xticks([])
            ax.set_yticks([])
            ax.set_yticklabels(ax.get_yticklabels(), rotation=0, fontsize=9)
            # Thin light-grey frame around the heatmap.
            for spine in ax.spines.values():
                spine.set_visible(True)
                spine.set_color("lightgrey")
                spine.set_linewidth(0.5)
            # NOTE(review): assumed to sit after the loop, not inside it --
            # the original indentation is not recoverable from this copy.
            ax.set_yticklabels(ax.get_yticklabels(), rotation=0)

        def _plot_component(comp_name, ax):
            # Spectrum panel, with the component name as a horizontal y label.
            ax = plot_spectrum(fetch_component(dataset, comp_name), ax=ax)
            ax.set_ylabel(
                comp_name, fontsize=9, rotation=0, labelpad=10, va="center", ha="right"
            )
            return ax

        def _plot_shap(comp_name, ax):
            defaults = dict(
                cbar=False,
                scale=50,
                max_size=3,
                alpha=0.8,
            )
            # Caller-supplied SHAP options override the defaults above.
            defaults.update(shap_plot_kw)
            # ``phases`` is assigned further down in mega_summary; it is
            # resolved when this helper is called, not when it is defined.
            ax = plot_shap_summary(
                dataset,
                component_order=[comp_name],
                feature_order=feature_order
                + ["Repliseq" + phase for phase in phases],
                ax=ax,
                **defaults,
            )
            ax.set(
                ylabel="",
                xlabel="",
                yticklabels=[],
                xticklabels=[],
            )
            return ax

        # NOTE(review): this module defines ``order_components``; whether
        # ``tr.order_components`` also exists cannot be seen from here.
        component_order = tr.order_components(dataset)
        phases = ["G1b", "S1", "S2", "S3", "S4", "G2"]

        feature_hm = tr.heatmap_plot(
            tr.pipeline(
                tr.feature_matrix(*feature_order),
                tr.clip(0, 0.97),
                lambda x: np.log1p(x),
                tr.apply_rows(minmax_scale),
                lambda x: x.fillna(0.0),
            ),
            palette="viridis",
            cbar=False,
            label="Functional\nfeatures",
        )
        repliseq = tr.heatmap_plot(
            tr.pipeline(
                tr.feature_matrix(*["Repliseq" + phase for phase in phases]),
                tr.apply_rows(minmax_scale),
            ),
            palette="crest_r",
            cbar=False,
            label="Cell cycle\nphase",
            ax_fn=lambda ax: (ax.set_yticklabels(reversed(phases))),
        )

        # In the rows below ``...`` is passed where a column has no track;
        # presumably tr.columns treats it as an empty slot -- TODO confirm.
        return (
            tr.columns(
                ...,
                tr.scale_bar(1_000_000, scale="mb"),
                ...,
                ...,
                height=0.1,
            ),
            tr.columns(
                ...,
                # NOTE(review): absolute path on one developer's machine.
                tr.ideogram(
                    "/Users/allen/projects/mutopia/signaturemodels/notebooks/cytoBand.txt",
                    height=0.1,
                ),
                ...,
                ...,
                height=0.1,
            ),
            tr.columns(
                ...,
                # NOTE(review): this module's function is named
                # ``plot_marginal_observed_vs_expected``; confirm that the
                # un-prefixed name used here still resolves.
                tr.tracks.marginal_observed_vs_expected(
                    view,
                    legend=False,
                    label="Mutation\nrate",
                    **mutation_rate_plot_kw,
                ),
                ...,
                ...,
            ),
            tr.columns(
                ...,
                feature_hm,
                ...,
                ...,
                height=1.5,
            ),
            tr.columns(
                ...,
                repliseq,
                ...,
                ...,
                height=0.95,
            ),
            tr.columns(
                tr.text_banner("Component spectra"),
                tr.text_banner("Component mutation rates"),
                tr.text_banner("Strand effects"),
                tr.text_banner("SHAP value summaries"),
                height=0.5,
            ),
            [
                tr.columns(
                    tr.custom_plot(partial(_plot_component, component_name)),
                    # NOTE(review): this module's function is named
                    # ``plot_component_rates``; confirm ``tr.component_rates``.
                    tr.component_rates(view, component_name, label=" ", smooth=20)[0],
                    tr.custom_plot(partial(_plot_shared_effects, component_name)),
                    tr.custom_plot(partial(_plot_shap, component_name)),
                    height=0.5,
                )
                for component_name in component_order
            ],
        )

    def plot_mega_summary(dataset, **kw):
        """
        Render ``mega_summary`` for the fixed region chr1:49007569-69616441.

        Extra keyword arguments are forwarded to ``tr.plot_view``.
        """
        view = tr.make_view(dataset, "chr1:49007569-69616441", title=None)
        return tr.plot_view(
            mega_summary,
            view,
            dataset=dataset,
            **kw,
            width_ratios=[3, 4, 1, 4],
            width=11,
            # NOTE(review): "gridpsec_kw" looks like a misspelling of
            # "gridspec_kw" -- check which keyword tr.plot_view accepts.
            gridpsec_kw={"hspace": 0.3, "wspace": 0.2},
        )