Source code for mutopia.plot.report

from __future__ import annotations

from typing import Optional, TYPE_CHECKING

if TYPE_CHECKING:
    from matplotlib.figure import Figure
    from mutopia.gtensor.gtensor import GTensorDataset


[docs] def plot_signature_report( dataset: "GTensorDataset", component, width: float = 5.25, height: float = 2.0, show: bool = True, bubble_scale: float = 300, ) -> Optional["Figure"]: """ Generate a comprehensive report for a specific signature component. This method creates a figure with signature plots for mesoscale states and an interaction matrix for the specified component, providing a visual representation of the signature's characteristics. Parameters ---------- dataset : GTensorDataset Dataset containing the signature data. component : int or str The signature component to visualize. Can be an integer index or a string identifier. width : float, default=5.25 The base width of the figure in inches. The actual figure width may be adjusted based on the number of states. height : float, default=2.0 The base height per signature group in inches. show : bool, default=True Whether to display the figure immediately. bubble_scale : float, default=300 Scale parameter passed to the SHAP summary bubble sizes. Returns ------- matplotlib.figure.Figure or None The generated figure when ``show=False``; otherwise ``None``. Notes ----- The report organizes mesoscale states into groups based on their prefix (before the colon), and displays them in separate rows. For singleton state groups (except Baseline), the Baseline state is automatically added as a reference. """ import matplotlib.pyplot as plt from collections import defaultdict from .signature_plot import plot_spectrum from .coef_matrix_plot import plot_interaction_matrix from ..gtensor import fetch_component from .bubble_plot import plot_shap_summary signatures = fetch_component(dataset, component) n_rows = len(signatures.genome_state) state_groups = defaultdict(list) for state in signatures.genome_state.values: state_groups[state.split(":")[0]].append(state) # for k, v in state_groups.items(): # if not k == "Baseline" and len(v) == 1: # state_groups[k].append("Baseline") max_n_states = max(map(len, state_groups.values())) n_sigs = len(state_groups) fig = plt.figure(figsize=(max(width * max_n_states, 10), height * n_sigs + 3)) gs = fig.add_gridspec( 3, 1, height_ratios=[height * n_sigs + 2, 1.5, 2.5 + 0.35 * n_rows], hspace=0.35, ) gs0 = gs[0].subgridspec( n_sigs + 1, max_n_states, hspace=0.75, wspace=0.5, width_ratios=[3] + [1] * (max_n_states - 1), ) for i, states in enumerate(state_groups.values()): ax = fig.add_subplot(gs0[i, : len(states)]) plot_spectrum( signatures, *states, ax=ax, ) shap_ax = fig.add_subplot(gs[1, 0]) plot_shap_summary( dataset, component_order=[component], ax=shap_ax, scale=bubble_scale, ) shap_ax.set( ylabel="Feature impacts", xlabel="", ) shap_ax.set_yticklabels([]) plot_interaction_matrix( dataset, component, gridspec=gs[2], ) fig.suptitle(f"Component {component} report", fontsize=12, y=0.95) if show: plt.show() return fig