Source code for stressnet.utils.plotting

"""Matplotlib helpers for visualizing ForSys frames and inferred forces."""

# Public API of this module: only the plotting entry point is exported.
__all__ = [
    "plot_with_force",
]

import os
from typing import TYPE_CHECKING, Literal

import cmocean  # noqa: F401  # registers "cmo.*" colormaps with matplotlib
import matplotlib
import numpy as np
from matplotlib import pyplot as plt

if TYPE_CHECKING:
    from forsys.frames import Frame

# Default colorbar tick spacing; force limits are snapped outward to multiples of it.
CBAR_STEP: float = 0.2
# Colormap used when each internal big edge gets its own discrete colour.
DEFAULT_DISCRETE_CMAP: str = "tab20"
# Colormap used for continuous force values ("cmo.*" names are registered by importing cmocean).
DEFAULT_CONTINUOUS_CMAP: str = "cmo.matter"


def _normalize_force_to_plot(force_to_plot: Literal['stress', 'tension', 'gt', 'ground-truth'] | None
                             ) -> Literal['stress', 'tension', 'gt', 'ground-truth']:
    if force_to_plot is None:
        return None
    if force_to_plot in ('stress', 'tension'):
        return 'tension'
    if force_to_plot in ('gt', 'ground-truth'):
        return 'gt'
    raise ValueError(f'"{force_to_plot}" not supported, must provide "stress", "tension", "gt", or "ground-truth".')


def _snapped_norm(values, cbar_step: float):
    """Build a colour normaliser whose limits are snapped outward to multiples of ``cbar_step``.

    For example, with ``cbar_step=0.2`` a data range of ``[0.3, 1.7]`` becomes ``[0.2, 1.8]``.

    Parameters
    ----------
    values
        Sequence of force values to be coloured.
    cbar_step
        Positive step used for snapping (and later for colorbar ticks).

    Returns
    -------
    matplotlib.colors.Normalize or None
        ``None`` when ``values`` is empty, i.e. there is nothing to colour.

    Raises
    ------
    ValueError
        If ``cbar_step`` is not positive.
    """
    if cbar_step <= 0:
        raise ValueError(f'cbar_step must be positive, got {cbar_step!r}.')
    if len(values) == 0:
        return None
    vmin = np.floor(np.min(values) / cbar_step) * cbar_step
    vmax = np.ceil(np.max(values) / cbar_step) * cbar_step
    if vmax <= vmin:
        # Every value sits exactly on one step multiple: widen by one step so the
        # normaliser and the colorbar are not degenerate (vmin == vmax).
        vmax = vmin + cbar_step
    return matplotlib.colors.Normalize(vmin=vmin, vmax=vmax)


def _add_colorbar(ax, colormap, norm, cbar_step: float,
                  cbar_params: dict | None = None, tick_params: dict | None = None) -> None:
    """Attach a colorbar for ``colormap``/``norm`` next to ``ax``, with one tick every ``cbar_step``."""
    mappable = matplotlib.cm.ScalarMappable(cmap=colormap, norm=norm)
    mappable.set_array([])
    params = {'pad': 0.04, 'shrink': 0.7, 'format': '%.1f', **(cbar_params or {})}
    cbar = plt.colorbar(mappable, ax=ax, **params)
    # Half-step slack keeps the top tick despite floating-point round-off in arange.
    cbar.set_ticks(np.arange(norm.vmin, norm.vmax + cbar_step / 2, cbar_step))
    if tick_params:
        cbar.ax.tick_params(**tick_params)


def _resolve_output_path(filename: str) -> str:
    """Return ``filename``, appending ``.png`` unless it already ends in ``.png``/``.pdf``/``.svg``.

    The check is case-insensitive and only inspects the final path component, so
    dotted directory names (e.g. ``runs/v1.2/fig``) are not mistaken for an extension.
    """
    extension = os.path.splitext(filename)[1].lower()
    return filename if extension in ('.png', '.pdf', '.svg') else filename + '.png'


def plot_with_force(frame: 'Frame', filename: str | None = None,
                    force_to_plot: Literal['stress', 'tension', 'gt', 'ground-truth'] | None = None,
                    mirror_y: bool = False, figsize: tuple[float, float] = (10, 10), **kwargs) -> None:
    """Plot the tissue graph, optionally coloring edges by tension or ground truth.

    Parameters
    ----------
    frame
        ForSys frame with vertices/edges populated.
    filename
        If set, save the figure to this path. ``.png`` is appended unless the name
        already ends in ``.png``, ``.pdf`` or ``.svg`` (case-insensitive). Missing
        parent directories are created.
    force_to_plot
        ``None`` draws discrete colors per internal big edge; ``'stress'`` /
        ``'tension'`` maps continuous edge ``tension``; ``'gt'`` / ``'ground-truth'``
        maps ``gt``.
    mirror_y
        If ``True``, invert the y-axis (image-style coordinates).
    figsize
        Matplotlib figure size in inches.
    **kwargs
        Common keys: ``cmap``, ``cbar``, ``cbar_step`` (must be positive in continuous
        mode), ``cbar_params`` (merged into ``plt.colorbar``), ``cbar_tick_params``
        (passed to the colorbar's ``tick_params``), ``title``, ``dpi`` (saved-figure
        resolution, default 500), ``plot_kwargs`` (merged into ``plot`` for
        non-external edges).

    Returns
    -------
    None
        Shows the figure with ``plt.show()`` when ``filename`` is ``None``;
        otherwise writes the file and closes the figure.

    Raises
    ------
    ValueError
        If ``force_to_plot`` is not a supported name, or ``cbar_step`` is not
        positive while plotting a continuous force.
    """
    force_to_plot = _normalize_force_to_plot(force_to_plot)
    plt.close()
    _, ax = plt.subplots(figsize=figsize)
    ax.set_aspect('equal')
    if filename:
        new_dir = os.path.dirname(filename)
        if new_dir:
            os.makedirs(new_dir, exist_ok=True)

    cbar_step = kwargs.get('cbar_step', CBAR_STEP)
    external_edge_ids = {eid for be in frame.big_edges.values() if be.external for eid in be.edges}

    if force_to_plot is None:
        # Discrete mode: one palette entry per internal big edge, cycling through the colormap.
        colormap = plt.get_cmap(kwargs.get('cmap', DEFAULT_DISCRETE_CMAP))
        internal_big_edges = [be for be in frame.big_edges.values() if not be.external]
        color_index = {eid: i for i, be in enumerate(internal_big_edges) for eid in be.edges}
        cbar_norm = None

        def _get_edge_color(edge):
            return colormap(color_index[edge.id] % colormap.N)
    else:
        # Continuous mode: the colour scale spans only the internal edges' forces.
        colormap = plt.get_cmap(kwargs.get('cmap', DEFAULT_CONTINUOUS_CMAP))
        internal_forces = [getattr(edge, force_to_plot)
                           for eid, edge in frame.edges.items() if eid not in external_edge_ids]
        # None when there are no internal edges; _get_edge_color is then never called.
        cbar_norm = _snapped_norm(internal_forces, cbar_step)

        def _get_edge_color(edge):
            return colormap(cbar_norm(getattr(edge, force_to_plot)))

    external_style = {'color': 'black', 'linewidth': 0.5, 'alpha': 0.6}
    user_plot_kwargs = kwargs.get('plot_kwargs', {})
    for eid, edge in frame.edges.items():
        if eid in external_edge_ids:
            style = external_style
        else:
            style = {'color': _get_edge_color(edge), 'linewidth': 2, **user_plot_kwargs}
        ax.plot((edge.v1.x, edge.v2.x), (edge.v1.y, edge.v2.y), **style)

    ax.axis('off')
    if mirror_y:
        ax.invert_yaxis()

    if kwargs.get('cbar') and cbar_norm is not None:
        _add_colorbar(ax, colormap, cbar_norm, cbar_step,
                      cbar_params=kwargs.get('cbar_params'),
                      tick_params=kwargs.get('cbar_tick_params'))

    if kwargs.get('title'):
        ax.set_title(kwargs['title'])
    plt.tight_layout()

    if filename:
        plt.savefig(_resolve_output_path(filename), dpi=kwargs.get('dpi', 500), transparent=True)
        plt.close()
    else:
        # The original called plt.plot() here, which draws nothing and never displays the figure.
        plt.show()