# Source code for uchrom.recon.fish.gem_fish

"""Joint Hi-C + FISH 3-D reconstruction (GEM-FISH).

Independent PyTorch reimplementation of Abbas et al. 2019
(*Nat. Commun.*, doi:10.1038/s41467-019-10005-6).  Not derived from the
upstream MATLAB source (github.com/ahmedabbas81/GEM-FISH, MIT
licence).

Three-stage pipeline
--------------------
1. **TAD-level coarse model** — optimise the 3-D positions of each TAD
   centre with ``C_g = C_1 + λ_E·C_2 + λ_F·C_3`` (Hi-C KL + polymer
   prior + FISH inter-TAD distance constraints).  See
   :mod:`._tad_level`.
2. **Intra-TAD fine models** — for each TAD, optimise per-bin
   coordinates with ``C_t = C_1 + λ_E·C_2 + λ_R·C_4`` (intra-TAD Hi-C
   KL + polymer prior + FISH radius-of-gyration constraints).  See
   :mod:`._intra_tad`.
3. **Assembly** — translate each intra-TAD cloud so its centroid lies
   at the Stage-1 centre, then rotate each TAD (except the first) so
   its start-anchor points toward the expected neighbouring-TAD
   position.  See :mod:`._assembly`.
"""

from __future__ import annotations

from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Sequence, Tuple

import numpy as np
import pandas as pd

from ._hic import aggregate_over_tads, load_contact_matrix, contact_to_distance
from ._tad_level import Stage1Params, reconstruct_tad_level
from ._intra_tad import Stage2Params, reconstruct_intra_tad
from ._assembly import AssemblyParams, assemble


@dataclass
class GEMFISHParams:
    """Runtime parameters for :func:`reconstruct_gem_fish`.

    The fields are grouped by the pipeline stage that consumes them; every
    field has a default, so ``GEMFISHParams()`` is a valid configuration.
    """

    # Stage weights: ``lambda_E_tad`` / ``lambda_F`` are forwarded to
    # Stage 1, ``lambda_E_intra`` / ``lambda_R`` to Stage 2.
    lambda_E_tad: float = 0.05
    lambda_F: float = 0.01
    lambda_E_intra: float = 0.05
    lambda_R: float = 0.01

    # Stage-1 optimisation
    stage1_iter: int = 1000
    stage1_lr: float = 0.05
    stage1_ensemble: int = 20
    stage1_optimiser: str = "adam"

    # Stage-2 optimisation
    stage2_iter: int = 500
    stage2_lr: float = 0.05
    stage2_ensemble: int = 10

    # Hi-C → distance conversion (Eqn. 10); passed to ``contact_to_distance``
    # when building the inter-TAD distances used by the assembly stage.
    hic_alpha: float = 0.25
    hic_fallback_c: float = 1.0
    hic_fallback_beta: float = 0.5

    # TAD caller passthrough (for the default DI caller)
    tad_di_window: int = 50
    tad_di_smoothing: float = 0.1
    tad_di_min_size_frac: float = 0.05
    tad_max_size_frac: float = 0.05
    """Cap any called TAD at ``tad_max_size_frac * n_bins`` bins.

    Oversized DI windows (often in heterochromatic regions like the chr21
    p-arm where DI signal is weak) are split into equal chunks.  Set to 1.0
    to disable capping."""

    # Shared
    device: str = "auto"
    verbose: bool = True
# ---------------------------------------------------------------------- # TAD partitioning (default path: strc.tad.get_domains) # ---------------------------------------------------------------------- def _auto_tads( contacts: np.ndarray, params: GEMFISHParams, ) -> List[Tuple[int, int]]: """Run the Dixon DI TAD caller to partition ``contacts`` into contiguous bin windows ``[(s_0, e_0), (s_1, e_1), ...]``. """ import torch from uchrom.strc.tad import get_domains mat = torch.as_tensor(contacts, dtype=torch.float64) windows = get_domains( mat, smoothing_param=params.tad_di_smoothing, min_size_frac=params.tad_di_min_size_frac, window=params.tad_di_window, device=params.device, ) # get_domains returns end-exclusive (start_idx, end_idx) tuples that # mostly tile the chromosome — the last window may stop a few bins # short. Extend it to cover the full matrix so downstream bin # counts match. out = [(int(s), int(e)) for s, e in windows] if out and out[-1][1] < contacts.shape[0]: out[-1] = (out[-1][0], int(contacts.shape[0])) # Split oversized TADs into equal chunks. The Dixon DI caller # happily lumps low-signal regions (centromeric / heterochromatic # arms) into one huge TAD — those break the later assembly step # because a 500-bin Gaussian-chain cloud has a Rg so large it # can't sit adjacent to a 30-bin neighbour without a pathological # boundary bond. 
max_size = max(10, int(contacts.shape[0] * params.tad_max_size_frac)) capped = [] for s, e in out: n = e - s if n <= max_size: capped.append((s, e)) continue n_chunks = int(np.ceil(n / max_size)) chunk = int(np.ceil(n / n_chunks)) sub_s = s while sub_s < e: sub_e = min(sub_s + chunk, e) capped.append((sub_s, sub_e)) sub_s = sub_e return capped # ---------------------------------------------------------------------- # FISH aggregation # ---------------------------------------------------------------------- def _tad_fish_distances( fish_cd: Any, chrom: str, tad_bin_indices: List[Tuple[int, int]], bin_df: pd.DataFrame, ) -> np.ndarray: """Aggregate FISH-measured distances onto TAD centres. For each TAD pair ``(i, j)``, return the median observed distance between spots that fall inside TAD i and TAD j across all traces, averaged over traces in ``fish_cd``. NaN if no trace contributes. """ from uchrom.fea.distance import _bin_coord_cube, _pairwise_distance_per_trace if fish_cd is None: return np.full( (len(tad_bin_indices), len(tad_bin_indices)), np.nan, dtype=np.float64, ) df = fish_cd.get_chrom(chrom).to_dataframe() cube, fish_bin_ids, _ = _bin_coord_cube(df, chrom=chrom) # Bin-level distances across all traces D = _pairwise_distance_per_trace(cube) # (n_traces, n_fish_bins, n_fish_bins) # Median over traces (ignoring NaN) — our "population" FISH distance pop = np.nanmedian(D, axis=0) # Map each fish_bin into a TAD index via the Hi-C bin_df start/end, # by overlap of FISH (start, end) with the Hi-C bin_df rows. hic_bin_starts = bin_df["start"].to_numpy() hic_bin_ends = bin_df["end"].to_numpy() fish_to_tad = [] for (fs, fe) in fish_bin_ids: mid = 0.5 * (int(fs) + int(fe)) # find the Hi-C bin that contains this midpoint idx = int(np.searchsorted(hic_bin_ends, mid, side="right")) if idx >= len(hic_bin_starts) or hic_bin_starts[idx] > mid: fish_to_tad.append(-1) continue # which TAD covers this hic bin? 
tad_of_bin = -1 for ti, (s, e) in enumerate(tad_bin_indices): if s <= idx < e: tad_of_bin = ti break fish_to_tad.append(tad_of_bin) fish_to_tad = np.asarray(fish_to_tad, dtype=np.int64) n_tads = len(tad_bin_indices) sums = np.zeros((n_tads, n_tads), dtype=np.float64) cnts = np.zeros((n_tads, n_tads), dtype=np.int64) for i in range(pop.shape[0]): ti = fish_to_tad[i] if ti < 0: continue for j in range(i + 1, pop.shape[1]): tj = fish_to_tad[j] if tj < 0: continue if not np.isfinite(pop[i, j]): continue sums[ti, tj] += pop[i, j] cnts[ti, tj] += 1 sums[tj, ti] += pop[i, j] cnts[tj, ti] += 1 with np.errstate(invalid="ignore"): F = np.where(cnts > 0, sums / np.maximum(cnts, 1), np.nan) np.fill_diagonal(F, 0.0) return F def _tad_rg_targets( fish_cd: Any, chrom: str, tad_bin_indices: List[Tuple[int, int]], bin_df: pd.DataFrame, ) -> List[Optional[float]]: """Per-TAD FISH-derived target :math:`\\hat{R}_g^2`. For each TAD, we collect the spots (across all traces) that fall inside the TAD's bin range, compute each trace's intra-TAD Rg² (if at least 3 finite spots), and return the median across traces. Returns ``None`` for TADs without enough data. 
""" from uchrom.fea.distance import _bin_coord_cube if fish_cd is None: return [None] * len(tad_bin_indices) df = fish_cd.get_chrom(chrom).to_dataframe() cube, fish_bin_ids, _ = _bin_coord_cube(df, chrom=chrom) # Map FISH bins → TAD index (same as above) hic_starts = bin_df["start"].to_numpy() hic_ends = bin_df["end"].to_numpy() fish_to_tad = [] for (fs, fe) in fish_bin_ids: mid = 0.5 * (int(fs) + int(fe)) idx = int(np.searchsorted(hic_ends, mid, side="right")) if idx >= len(hic_starts) or hic_starts[idx] > mid: fish_to_tad.append(-1); continue tad_of_bin = -1 for ti, (s, e) in enumerate(tad_bin_indices): if s <= idx < e: tad_of_bin = ti; break fish_to_tad.append(tad_of_bin) fish_to_tad = np.asarray(fish_to_tad, dtype=np.int64) out = [] for ti in range(len(tad_bin_indices)): members = np.where(fish_to_tad == ti)[0] if members.size < 3: out.append(None); continue sub_cube = cube[:, members, :] # (n_traces, n_members, 3) # Per-trace Rg² if at least 3 finite spots rgs = [] for t in range(sub_cube.shape[0]): coords = sub_cube[t] finite = ~np.isnan(coords[:, 0]) if finite.sum() < 3: continue xyz = coords[finite] centroid = xyz.mean(axis=0) rg2 = ((xyz - centroid) ** 2).sum() / xyz.shape[0] rgs.append(rg2) out.append(float(np.median(rgs)) if rgs else None) return out # ---------------------------------------------------------------------- # Main entry # ----------------------------------------------------------------------
[docs] def reconstruct_gem_fish( hic_path: str, chrom: str, resolution: Optional[int] = None, fish_cd: Any = None, tads: Optional[pd.DataFrame] = None, params: Optional[GEMFISHParams] = None, return_intermediate: bool = False, ) -> Any: """End-to-end GEM-FISH reconstruction for a single chromosome. Parameters ---------- hic_path : str ``.cool`` or ``.mcool`` file. chrom : str Chromosome name as stored in the cooler. resolution : int, optional Required for ``.mcool``; ignored for ``.cool``. fish_cd : ChromData, optional FISH data, e.g. from ``ChromData.from_fofct(...)``. Supplies the C_3 (TAD-centre distances) and C_4 (per-TAD Rg²) targets. If ``None``, both FISH terms are dropped and the result degenerates to a Hi-C-only reconstruction. tads : DataFrame, optional Override the TAD partition. Must have columns ``start`` and ``end`` (bp) and be sorted. If omitted, the Dixon DI caller (:func:`uchrom.strc.tad.get_domains`) is run on the Hi-C matrix. params : GEMFISHParams return_intermediate : bool If True, return a dict with every intermediate artifact (Stage-1 coords, per-TAD coords, loss histories). Returns ------- ChromData Reconstructed 3-D model with bin-resolution coordinates. Each Hi-C bin becomes a spot; ``trace_id = 0`` for the whole chain (single consensus model). Per-stage info is stored in ``cd.uns['gem_fish']``. If ``return_intermediate=True``, returns ``(cd, dict)`` with the full artefact bundle. 
""" from uchrom import ChromData params = params or GEMFISHParams() if params.verbose: print(f"[gem_fish] loading Hi-C for {chrom} …") hic_mat, bin_df = load_contact_matrix( hic_path, chrom, resolution=resolution, normalize=True, ) if params.verbose: print(f"[gem_fish] Hi-C matrix: {hic_mat.shape} " f"nnz={int((hic_mat > 0).sum())}") # --- TAD partition if tads is None: if params.verbose: print("[gem_fish] calling TADs via Dixon DI …") tad_windows = _auto_tads(hic_mat, params) else: # Convert user-provided (start_bp, end_bp) DataFrame to bin indices starts = bin_df["start"].to_numpy() ends = bin_df["end"].to_numpy() tad_windows = [] for _, r in tads.iterrows(): s = int(np.searchsorted(ends, int(r["start"]) + 1, side="left")) e = int(np.searchsorted(starts, int(r["end"]), side="left")) if e <= s: e = s + 1 tad_windows.append((s, e)) if params.verbose: print(f"[gem_fish] {len(tad_windows)} TADs") # --- Stage 1: TAD-level inter_tad_contacts, _ = aggregate_over_tads( hic_mat, bin_df, pd.DataFrame([ {"start": int(bin_df["start"].iloc[s]), "end": int(bin_df["end"].iloc[e - 1])} for (s, e) in tad_windows ]), ) tad_fish_F = _tad_fish_distances(fish_cd, chrom, tad_windows, bin_df) if params.verbose: print("[gem_fish] Stage 1 — TAD-level reconstruction …") s1_params = Stage1Params( lambda_E=params.lambda_E_tad, lambda_F=params.lambda_F, n_iter=params.stage1_iter, lr=params.stage1_lr, n_ensemble=params.stage1_ensemble, optimiser=params.stage1_optimiser, device=params.device, verbose=params.verbose, ) tad_centres, s1_info = reconstruct_tad_level( inter_tad_contacts, tad_fish_F, params=s1_params, ) # --- Stage 2: intra-TAD rg_targets = _tad_rg_targets(fish_cd, chrom, tad_windows, bin_df) if params.verbose: print("[gem_fish] Stage 2 — intra-TAD reconstruction …") s2_params = Stage2Params( lambda_E=params.lambda_E_intra, lambda_R=params.lambda_R, n_iter=params.stage2_iter, lr=params.stage2_lr, n_ensemble=params.stage2_ensemble, device=params.device, verbose=params.verbose, ) 
intra_coords, s2_info = reconstruct_intra_tad( hic_mat, tad_windows, target_rg_sq_per_tad=rg_targets, params=s2_params, ) # --- Stage 3: assembly if params.verbose: print("[gem_fish] Stage 3 — assembly …") mid_bp = ((bin_df["start"].to_numpy() + bin_df["end"].to_numpy()) / 2) tad_mid_bp = np.array([ 0.5 * (mid_bp[s] + mid_bp[e - 1]) for (s, e) in tad_windows ]) genomic_dist = np.abs(tad_mid_bp[:, None] - tad_mid_bp[None, :]) inter_tad_gap = contact_to_distance( inter_tad_contacts, genomic_distances=genomic_dist, alpha=params.hic_alpha, fallback_c=params.hic_fallback_c, fallback_beta=params.hic_fallback_beta, ) final_coords = assemble( tad_centres, intra_coords, inter_tad_distances=inter_tad_gap, ) if params.verbose: print(f"[gem_fish] done. Output: {final_coords.shape}") # --- Wrap into ChromData spots = pd.DataFrame({ "chrom": bin_df["chrom"].values, "start": bin_df["start"].astype(int).values, "end": bin_df["end"].astype(int).values, "trace_id": 0, }) uns = { "gem_fish": { "n_tads": len(tad_windows), "stage1_final_loss": float(s1_info["final_loss"].min()), "stage1_best_ensemble": int(s1_info["best_ensemble_idx"]), }, "xyz_unit": "reduced", # arbitrary reconstruction units "genome_assembly": None, } cd = ChromData( coords=final_coords, spots=spots, uns=uns, ) if return_intermediate: return cd, { "tad_windows": tad_windows, "tad_centres": tad_centres, "intra_coords": intra_coords, "inter_tad_contacts": inter_tad_contacts, "inter_tad_gap": inter_tad_gap, "tad_fish_F": tad_fish_F, "rg_targets": rg_targets, "stage1_info": s1_info, "stage2_info": s2_info, } return cd
# Public API of this module.
__all__ = ["GEMFISHParams", "reconstruct_gem_fish"]