# Source code for uchrom.pl.structure_3d

"""PyVista-based 3-D rendering of reconstructed chromatin structures.

Draws a :class:`~uchrom.core.ChromData` (or raw coordinate arrays) as
smooth tubes along the genomic chain — matching the look of the
interactive ``uchrom.browser`` without requiring Qt.

Works in notebooks (inline screenshot, or interactive if the ``pythreejs`` /
``trame`` backends are installed) and in scripts (save to PNG, or show
an OS window).  PyVista is imported lazily so importing ``uchrom.pl``
keeps its zero-cost path when the user only needs matplotlib helpers.

Colour options
--------------
``colour`` accepts:

- A matplotlib-style colour name or ``(r, g, b)`` tuple → whole chain
  painted uniformly.
- ``'bin'`` → gradient along bin index using ``cmap``.
- ``'chrom'`` → one colour per chromosome (requires ``spots`` to carry
  ``chrom`` — automatically true for a :class:`ChromData`).
- ``'trace'`` → one colour per trace (distinct alleles).
- An ``(N,)`` numpy array of scalar values (e.g. epigenetic signal) →
  mapped through ``cmap``.
"""

from __future__ import annotations

from typing import Any, Optional, Sequence, Tuple, Union

import numpy as np


# Matplotlib colour-map name used as the default for the ``cmap``
# parameter of ``plot_structure_3d``.
_DEFAULT_CMAP = "viridis"


# ----------------------------------------------------------------------
# Coordinate extraction
# ----------------------------------------------------------------------

def _coords_and_groups(
    cd_or_coords: Any,
    chrom: Optional[str] = None,
    trace_id: Optional[Union[int, str]] = None,
) -> Tuple[list, np.ndarray, Optional[Any]]:
    """Return ``(coord_segments, scalar_per_bin_fallback, spots_df)``.

    ``coord_segments`` is a list of ``(coords, label)`` tuples so that
    each contiguous stretch can be rendered as a separate tube (chain
    jumps between chromosomes / traces become gaps rather than fake
    connecting edges).

    Parameters
    ----------
    cd_or_coords : ChromData or array-like (n_bins, 3)
        Structure to split.  Anything that is not a ``ChromData`` is
        converted to a float64 array and returned as a single segment.
    chrom : str, optional
        Keep only spots whose ``chrom`` column equals this value (both
        sides compared as strings).  Ignored for raw coordinate arrays.
    trace_id : int or str, optional
        Keep only spots whose ``trace_id`` equals this value (compared
        as strings).  Applied only when the spots table has a
        ``trace_id`` column; ignored for raw coordinate arrays.

    Returns
    -------
    coord_segments : list of (ndarray, dict)
        One entry per contiguous sub-chain.  ``label`` holds the
        ``chrom`` / ``trace_id`` values (as strings) of the sub-chain for
        whichever of those columns exist, and is empty otherwise.
    scalar_per_bin_fallback : ndarray
        ``np.arange(n_bins)`` over the returned coordinates, in segment
        order (an empty array when the filters removed everything).
    spots_df : DataFrame or None
        The filtered, sorted spots table; ``None`` for raw arrays.

    Raises
    ------
    ValueError
        If a raw input is not a 2-D array with exactly three columns.
    """
    try:
        from uchrom.core import ChromData  # local import
    except Exception:
        ChromData = None  # type: ignore

    # Raw coordinates: one unlabelled segment, nothing to filter or sort.
    if ChromData is None or not isinstance(cd_or_coords, ChromData):
        coords = np.asarray(cd_or_coords, dtype=np.float64)
        if coords.ndim != 2 or coords.shape[1] != 3:
            raise ValueError(
                "Expected a ChromData or a (n_bins, 3) coordinate array"
            )
        return [(coords, {})], np.arange(coords.shape[0]), None

    spots = cd_or_coords.spots.copy()
    coords = np.asarray(cd_or_coords.coords)

    if chrom is not None:
        keep = (spots["chrom"].astype(str) == str(chrom)).to_numpy()
        spots = spots[keep].reset_index(drop=True)
        coords = coords[keep]
    if trace_id is not None and "trace_id" in spots.columns:
        keep = (spots["trace_id"].astype(str) == str(trace_id)).to_numpy()
        spots = spots[keep].reset_index(drop=True)
        coords = coords[keep]

    # Sort by (chrom, trace_id, start) so tubes follow the chain.  After
    # resetting the index, the sorted index *is* the positional
    # permutation, which is then applied to the coordinate rows.
    order_cols = [c for c in ("chrom", "trace_id", "start")
                  if c in spots.columns]
    if order_cols:
        spots = spots.reset_index(drop=True)
        spots = spots.sort_values(order_cols, kind="mergesort")
        coords = coords[spots.index.to_numpy()]
        spots = spots.reset_index(drop=True)

    if spots.empty:
        return [], np.empty(0), spots

    # Split wherever chrom / trace_id changes, using whichever of the two
    # columns is available, so a table carrying only one of them still
    # gets gaps instead of fake connecting edges.
    key_cols = [c for c in ("chrom", "trace_id") if c in spots.columns]
    if not key_cols:
        segments = [(coords, {})]
    else:
        keys = list(zip(*(spots[c].astype(str).tolist() for c in key_cols)))
        bounds = [0]
        bounds.extend(
            i for i in range(1, len(keys)) if keys[i] != keys[i - 1]
        )
        bounds.append(len(keys))
        segments = [
            (coords[a:b], dict(zip(key_cols, keys[a])))
            for a, b in zip(bounds[:-1], bounds[1:])
        ]
    return segments, np.arange(len(coords)), spots


# ----------------------------------------------------------------------
# Colour resolution
# ----------------------------------------------------------------------

# Qualitative palette used for the categorical 'chrom' / 'trace' modes.
_CATEGORICAL_PALETTE = (
    "#e41a1c", "#377eb8", "#4daf4a", "#984ea3", "#ff7f00",
    "#a65628", "#f781bf", "#999999", "#66c2a5", "#fc8d62",
)


def _stable_palette_slot(key: Any, n_slots: int) -> int:
    """Map ``key`` to a palette slot that is the same in every process."""
    import zlib  # local import: only the categorical modes need it

    # The builtin hash() of a str is salted per interpreter run, which
    # would reshuffle chromosome / trace colours between sessions.
    return zlib.crc32(str(key).encode("utf-8")) % n_slots


def _lookup_cmap(name: str):
    """Return the matplotlib colour map registered under ``name``."""
    import matplotlib

    # ``matplotlib.cm.get_cmap`` was removed in Matplotlib 3.9; the
    # ``matplotlib.colormaps`` registry is its replacement.  Fall back to
    # the old call only on releases that predate the registry.
    registry = getattr(matplotlib, "colormaps", None)
    if registry is not None:
        return registry[name]
    import matplotlib.cm as mpl_cm

    return mpl_cm.get_cmap(name)


def _resolve_color(
    mode: Union[str, Sequence, np.ndarray],
    segment_coords: np.ndarray,
    segment_meta: dict,
    all_bin_idx: np.ndarray,
    spots_df,
    cmap: str,
):
    """Return ``(color_array, per_vertex, label)`` for one chain segment.

    ``color_array`` is either a single RGB triple or a (N, 3) float
    array in [0, 1].  ``per_vertex=True`` when the tube should be
    coloured along its length.

    Parameters
    ----------
    mode : str, sequence or 1-D ndarray
        ``'bin'``, ``'chrom'``, ``'trace'``, any other string (resolved
        as a matplotlib colour), a 1-D array of scalars mapped through
        ``cmap``, or an RGB sequence whose first three items are used.
    segment_coords : ndarray
        Coordinates of the segment; only its length is used here.
    segment_meta : dict
        Segment labels; ``'chrom'`` and ``'trace_id'`` are read for the
        categorical modes.
    all_bin_idx : ndarray
        Indices into a scalar ``mode`` array; the first ``N`` entries are
        used, ``N`` being the segment length.
    spots_df
        Accepted for signature compatibility; not read by this function.
    cmap : str
        Matplotlib colour-map name for the ``'bin'`` and scalar modes.

    Returns
    -------
    tuple
        ``(color_array, per_vertex, label)`` where ``label`` is
        ``'bin'``, ``'chrom=<name>'``, ``'trace=<id>'``, ``'scalar'`` or
        ``None``.
    """
    N = segment_coords.shape[0]
    is_scalar_array = isinstance(mode, np.ndarray) and mode.ndim == 1

    # Plain (r, g, b) sequence: nothing to resolve, so matplotlib is not
    # needed on this path.
    if not isinstance(mode, str) and not is_scalar_array:
        return tuple(mode[:3]), False, None

    if is_scalar_array:
        # Index into the full scalar array by segment position, but
        # normalise against the whole array so every segment shares one
        # colour scale.
        seg_vals = mode[all_bin_idx[:N]] if len(mode) >= N else mode
        norm = (seg_vals - np.nanmin(mode)) / (
            max(np.nanmax(mode) - np.nanmin(mode), 1e-12)
        )
        return _lookup_cmap(cmap)(norm)[:, :3], True, "scalar"

    import matplotlib.colors as mpl_colors

    m = mode.lower()
    if m == "bin":
        t = np.linspace(0, 1, N) if N > 0 else np.empty(0)
        return _lookup_cmap(cmap)(t)[:, :3], True, "bin"
    if m == "chrom":
        chrom = segment_meta.get("chrom", "chr?")
        slot = _stable_palette_slot(chrom, len(_CATEGORICAL_PALETTE))
        return (
            mpl_colors.to_rgb(_CATEGORICAL_PALETTE[slot]), False,
            f"chrom={chrom}",
        )
    if m == "trace":
        tid = segment_meta.get("trace_id", "")
        slot = _stable_palette_slot(tid, len(_CATEGORICAL_PALETTE))
        return (
            mpl_colors.to_rgb(_CATEGORICAL_PALETTE[slot]), False,
            f"trace={tid}",
        )

    # Named colour.  Try the caller's spelling first because cycle
    # colours such as "C0" are case-sensitive; keep the lower-cased
    # lookup as a fallback so inputs like "R" still resolve.
    try:
        rgb = mpl_colors.to_rgb(mode)
    except ValueError:
        rgb = mpl_colors.to_rgb(m)
    return rgb, False, None


# ----------------------------------------------------------------------
# Public entry
# ----------------------------------------------------------------------

def _polyline_tube(pv, points, radius):
    """Pipe the straight-segment polyline through ``points`` into a tube."""
    n_pts = points.shape[0]
    poly = pv.PolyData()
    poly.points = points
    # VTK cell layout for one polyline: [n_pts, 0, 1, ..., n_pts - 1].
    cells = np.zeros(n_pts + 1, dtype=np.int_)
    cells[0] = n_pts
    cells[1:] = np.arange(n_pts)
    poly.lines = cells
    return poly.tube(radius=radius, n_sides=12)


def _split_at_long_bonds(indexed_segments, factor):
    """Cut ``(coords, meta, bin_idx)`` segments at bonds > factor * median."""
    bond_lengths = [
        np.linalg.norm(seg[1:] - seg[:-1], axis=1)
        for seg, _, _ in indexed_segments
        if seg.shape[0] >= 2
    ]
    if not bond_lengths:
        return indexed_segments

    # The median is taken over the bonds of every segment together.
    thr = float(factor * np.median(np.concatenate(bond_lengths)))
    refined = []
    for seg, meta, idx in indexed_segments:
        if seg.shape[0] < 2:
            refined.append((seg, meta, idx))
            continue
        d = np.linalg.norm(seg[1:] - seg[:-1], axis=1)
        cuts = np.where(d > thr)[0] + 1
        if cuts.size == 0:
            refined.append((seg, meta, idx))
            continue
        bounds = [0, *cuts.tolist(), seg.shape[0]]
        for a, b in zip(bounds[:-1], bounds[1:]):
            # Pieces shorter than two points cannot form a tube.
            if b - a >= 2:
                refined.append((seg[a:b], meta, idx[a:b]))
    return refined


def plot_structure_3d(
    cd_or_coords: Any,
    chrom: Optional[str] = None,
    trace_id: Optional[Union[int, str]] = None,
    colour: Union[str, Sequence, np.ndarray] = "bin",
    cmap: str = _DEFAULT_CMAP,
    tube_radius: Optional[float] = None,
    spline_smoothing: bool = True,
    background: str = "white",
    show_scalar_bar: Optional[bool] = None,
    save_png: Optional[str] = None,
    window_size: Tuple[int, int] = (900, 700),
    show: bool = True,
    plotter: Any = None,
    notebook: Optional[bool] = None,
    jupyter_backend: str = "static",
    bond_outlier_factor: Optional[float] = None,
):
    """Render a ``ChromData`` (or ``(n_bins, 3)`` coords) as a 3-D tube.

    Parameters
    ----------
    cd_or_coords : ChromData or ndarray (n_bins, 3)
        The structure to plot.  A ``ChromData`` is sorted by
        ``(chrom, trace_id, start)`` and broken into contiguous
        sub-chains, each drawn as one tube so discontinuities are not
        bridged by fake edges.
    chrom : str, optional
        Restrict to one chromosome.
    trace_id : int or str, optional
        Restrict to one trace / allele.
    colour : 'bin' | 'chrom' | 'trace' | named colour | ndarray
        See the module docstring for the full colour contract.  A 1-D
        scalar array is indexed by bin position in the filtered, sorted
        chain and must hold at least one value per bin.
    cmap : str
        Matplotlib colour-map name used when ``colour`` is a gradient or
        a scalar array.
    tube_radius : float, optional
        Radius of the rendered tube.  If ``None``, auto-picks 1 % of the
        bounding-box diagonal, with a floor of ``0.01``.
    spline_smoothing : bool
        If True, interpolate each sub-chain with ``pyvista.Spline``
        before piping (nicer curvature at sharp bends).  If the spline
        cannot be built, the sub-chain falls back to straight segments.
    background : str
        Background colour, applied when this function creates the
        plotter.
    show_scalar_bar : bool, optional
        Accepted for API compatibility but not consulted: per-vertex
        colours are attached as RGB arrays, so no colour-bar is drawn.
    save_png : str, optional
        If given, write a PNG to this path.
    window_size : (int, int)
        Renderer resolution in pixels.
    show : bool
        If True, call ``plotter.show()`` at the end (pops a window or
        renders inline in a notebook).
    plotter : pyvista.Plotter, optional
        Add to an existing plotter instead of creating one.
    notebook : bool, optional
        If None, auto-detect through ``IPython.get_ipython()``.
    jupyter_backend : str
        Passed to ``plotter.show(jupyter_backend=...)`` — ``'static'``
        renders an inline image (safest for nbconvert / CI), ``'trame'``
        gives an interactive widget.
    bond_outlier_factor : float, optional
        If set, any consecutive-bin bond longer than
        ``factor * median_bond`` is treated as a discontinuity and the
        chain is split at that point (rendered as two separate tubes
        with a visible gap).  Default ``None`` keeps the chain
        continuous — use this when the algorithm produces a legitimate
        chain that just happens to have a scale mismatch between
        intra-TAD bonds and inter-TAD boundary bonds (the typical
        two-stage reconstruction case).  Set to e.g. ``50`` to catch
        only pathologically long "bridges".

    Returns
    -------
    plotter : pyvista.Plotter
        The plotter object (useful for adding annotations before calling
        ``show()``).  When ``show`` is true in a notebook, the return
        value of ``plotter.show(...)`` is returned instead.

    Raises
    ------
    ImportError
        If PyVista is not installed.
    ValueError
        If there are no coordinates to render, or a scalar ``colour``
        array is shorter than the number of bins.
    """
    try:
        import pyvista as pv
    except ImportError as exc:  # pragma: no cover
        raise ImportError(
            "pyvista is required for plot_structure_3d — "
            "install it or use the browser module instead"
        ) from exc

    segments, all_bin_idx, spots_df = _coords_and_groups(
        cd_or_coords, chrom=chrom, trace_id=trace_id,
    )
    if not segments:
        raise ValueError("No coordinates to render — empty filter?")

    # Pair each segment with the global bin indices it covers, so a scalar
    # ``colour`` array is sampled at the segment's own bins rather than at
    # the first N bins of the chain.
    indexed = []
    offset = 0
    for seg, meta in segments:
        n_seg = seg.shape[0]
        indexed.append((seg, meta, all_bin_idx[offset:offset + n_seg]))
        offset += n_seg

    if (isinstance(colour, np.ndarray) and colour.ndim == 1
            and colour.shape[0] < offset):
        raise ValueError(
            f"colour array has {colour.shape[0]} values but the structure "
            f"has {offset} bins"
        )

    # Split each contiguous segment further at pathologically long bonds
    # so we don't draw fake bridges across algorithm-induced
    # discontinuities.
    if bond_outlier_factor is not None and bond_outlier_factor > 0:
        indexed = _split_at_long_bonds(indexed, bond_outlier_factor)

    # Auto tube radius from global extent.
    populated = [seg for seg, _, _ in indexed if seg.shape[0] > 0]
    if not populated:
        raise ValueError("No coordinates to render — empty filter?")
    all_xyz = np.vstack(populated)
    extent = float(np.linalg.norm(all_xyz.max(0) - all_xyz.min(0)))
    if tube_radius is None:
        tube_radius = max(extent * 0.010, 0.01)

    if notebook is None:
        try:
            from IPython import get_ipython
            notebook = get_ipython() is not None
        except Exception:
            notebook = False

    if plotter is None:
        plotter = pv.Plotter(
            notebook=notebook,
            off_screen=save_png is not None and not show,
            window_size=list(window_size),
        )
        # NOTE(review): the original layout was lost in extraction, so it
        # is unclear whether the background was also forced onto a
        # caller-supplied plotter; here it is left untouched — confirm.
        plotter.set_background(background)

    mesh_kw = dict(smooth_shading=True, specular=0.3, specular_power=15)
    for segment_coords, meta, seg_bin_idx in indexed:
        n_bins = segment_coords.shape[0]
        if n_bins < 2:
            continue
        col, per_vertex, _label = _resolve_color(
            colour, segment_coords, meta, seg_bin_idx, spots_df, cmap,
        )

        tube = None
        if spline_smoothing:
            try:
                spline = pv.Spline(
                    segment_coords, n_points=min(n_bins * 3, 1500),
                )
                tube = spline.tube(radius=tube_radius, n_sides=12)
            except Exception:
                # Best effort: fall back to straight segments below.
                tube = None
        if tube is None:
            tube = _polyline_tube(pv, segment_coords, tube_radius)

        col_arr = np.asarray(col) if per_vertex else None
        if (col_arr is not None and col_arr.ndim == 2
                and col_arr.shape[0] == n_bins):
            # Map colours onto the tube by bin index along the chain:
            # tube vertices are spread linearly over the bins.
            n_out_pts = tube.points.shape[0]
            scalar_idx = np.linspace(0, n_bins - 1, n_out_pts)
            scalar_idx = np.clip(scalar_idx.astype(int), 0, n_bins - 1)
            tube["rgb"] = (col_arr[scalar_idx] * 255).astype(np.uint8)
            plotter.add_mesh(tube, scalars="rgb", rgb=True, **mesh_kw)
        else:
            plotter.add_mesh(tube, color=tuple(col), **mesh_kw)

    plotter.add_axes()
    plotter.show_bounds(
        grid="back", location="outer", ticks="outside",
        xtitle="x", ytitle="y", ztitle="z",
    )

    if show:
        # Let show() write the PNG: an on-screen plotter has nothing to
        # capture until it has been rendered.
        show_kw = {"screenshot": save_png} if save_png else {}
        if notebook:
            return plotter.show(jupyter_backend=jupyter_backend, **show_kw)
        plotter.show(**show_kw)
    elif save_png:
        plotter.screenshot(save_png)
    return plotter


# Public API of this module: only the plotting entry point is exported;
# the underscore-prefixed helpers stay private.
__all__ = ["plot_structure_3d"]