"""Joint Hi-C + FISH 3-D reconstruction (GEM-FISH).
Independent PyTorch reimplementation of Abbas et al. 2019
(*Nat. Commun.*, doi:10.1038/s41467-019-10005-6). Not derived from the
upstream MATLAB source (github.com/ahmedabbas81/GEM-FISH, MIT
licence).

Three-stage pipeline
--------------------
1. **TAD-level coarse model** — optimise the 3-D positions of each TAD
centre with ``C_g = C_1 + λ_E·C_2 + λ_F·C_3`` (Hi-C KL + polymer
prior + FISH inter-TAD distance constraints). See
:mod:`._tad_level`.
2. **Intra-TAD fine models** — for each TAD, optimise per-bin
coordinates with ``C_t = C_1 + λ_E·C_2 + λ_R·C_4`` (intra-TAD Hi-C
KL + polymer prior + FISH radius-of-gyration constraints). See
:mod:`._intra_tad`.
3. **Assembly** — translate each intra-TAD cloud so its centroid lies
at the Stage-1 centre, then rotate each TAD (except the first) so
its start-anchor points toward the expected neighbouring-TAD
position. See :mod:`._assembly`.
"""
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Sequence, Tuple
import numpy as np
import pandas as pd
from ._hic import aggregate_over_tads, load_contact_matrix, contact_to_distance
from ._tad_level import Stage1Params, reconstruct_tad_level
from ._intra_tad import Stage2Params, reconstruct_intra_tad
from ._assembly import AssemblyParams, assemble
@dataclass
class GEMFISHParams:
    """Runtime parameters for :func:`reconstruct_gem_fish`.

    Fields are grouped by the pipeline stage that consumes them. The
    ``tad_*`` fields only matter when no explicit TAD table is passed,
    i.e. when the Dixon DI caller is run automatically.
    """

    # --- Stage weights
    # Polymer-prior weight λ_E of the Stage-1 (TAD-level) cost C_g.
    lambda_E_tad: float = 0.05
    # Weight λ_F of the FISH inter-TAD distance term C_3 in C_g.
    lambda_F: float = 0.01
    # Polymer-prior weight λ_E of the Stage-2 (intra-TAD) cost C_t.
    lambda_E_intra: float = 0.05
    # Weight λ_R of the FISH radius-of-gyration term C_4 in C_t.
    lambda_R: float = 0.01

    # --- Stage-1 optimisation (forwarded to Stage1Params)
    stage1_iter: int = 1000  # -> n_iter
    stage1_lr: float = 0.05  # -> lr
    stage1_ensemble: int = 20  # -> n_ensemble (independent restarts)
    stage1_optimiser: str = "adam"  # -> optimiser

    # --- Stage-2 optimisation (forwarded to Stage2Params)
    stage2_iter: int = 500  # -> n_iter
    stage2_lr: float = 0.05  # -> lr
    stage2_ensemble: int = 10  # -> n_ensemble

    # --- Hi-C -> distance conversion (Eqn. 10); used in Stage 3 to turn
    # aggregated inter-TAD contacts into target gap distances.
    hic_alpha: float = 0.25
    hic_fallback_c: float = 1.0
    hic_fallback_beta: float = 0.5

    # --- TAD caller passthrough (default Dixon DI caller only)
    tad_di_window: int = 50  # -> get_domains(window=...)
    tad_di_smoothing: float = 0.1  # -> get_domains(smoothing_param=...)
    tad_di_min_size_frac: float = 0.05  # -> get_domains(min_size_frac=...)
    # Cap any called TAD at ``max(10, int(tad_max_size_frac * n_bins))``
    # bins. Oversized DI windows (often in heterochromatic regions such as
    # the chr21 p-arm, where the DI signal is weak) are split into roughly
    # equal chunks. Set to 1.0 to disable capping.
    tad_max_size_frac: float = 0.05

    # --- Shared
    # Torch device for the TAD caller and both optimisation stages.
    device: str = "auto"
    # Print progress messages and forward verbosity to both stages.
    verbose: bool = True
# ----------------------------------------------------------------------
# TAD partitioning (default path: strc.tad.get_domains)
# ----------------------------------------------------------------------
def _tile_tad_windows(
    windows: Sequence[Tuple[int, int]],
    n_bins: int,
) -> List[Tuple[int, int]]:
    """Turn end-exclusive ``(start, end)`` windows into a contiguous
    partition of ``[0, n_bins)``.

    Downstream stages emit one spot per Hi-C bin, so every bin must fall
    in exactly one window. Windows are processed in sorted order and each
    one starts where the previous one ended: a leading gap or a gap
    between windows is absorbed into the following window, overlaps are
    trimmed, and windows that become empty are dropped. A trailing gap is
    absorbed into the last window. If nothing usable remains, the whole
    range becomes a single window. Input that already tiles
    ``[0, n_bins)`` is returned unchanged.
    """
    if n_bins <= 0:
        return []
    tiled: List[Tuple[int, int]] = []
    cursor = 0
    for _start, end in sorted((int(s), int(e)) for s, e in windows):
        end = min(end, n_bins)
        if end <= cursor:
            # Entirely covered by earlier windows (or empty).
            continue
        tiled.append((cursor, end))
        cursor = end
    if not tiled:
        return [(0, n_bins)]
    if cursor < n_bins:
        tiled[-1] = (tiled[-1][0], n_bins)
    return tiled


def _split_oversized_tads(
    windows: Sequence[Tuple[int, int]],
    max_size: int,
) -> List[Tuple[int, int]]:
    """Split every window longer than ``max_size`` (> 0) bins into the
    fewest near-equal consecutive chunks of at most ``max_size`` bins.
    """
    capped: List[Tuple[int, int]] = []
    for s, e in windows:
        n = e - s
        if n <= max_size:
            capped.append((s, e))
        else:
            n_chunks = -(-n // max_size)  # ceil(n / max_size)
            chunk = -(-n // n_chunks)  # ceil(n / n_chunks)
            capped.extend((b, min(b + chunk, e)) for b in range(s, e, chunk))
    return capped


def _auto_tads(
    contacts: np.ndarray,
    params: GEMFISHParams,
) -> List[Tuple[int, int]]:
    """Run the Dixon DI TAD caller to partition ``contacts`` into
    contiguous, end-exclusive bin windows ``[(s_0, e_0), (s_1, e_1), ...]``
    that together cover every row of the matrix.

    The caller's windows are made to tile ``[0, n_bins)``. Any window
    longer than ``max(10, int(n_bins * params.tad_max_size_frac))`` bins
    is then split into near-equal chunks.
    """
    import torch
    from uchrom.strc.tad import get_domains

    n_bins = int(contacts.shape[0])
    mat = torch.as_tensor(contacts, dtype=torch.float64)
    windows = get_domains(
        mat,
        smoothing_param=params.tad_di_smoothing,
        min_size_frac=params.tad_di_min_size_frac,
        window=params.tad_di_window,
        device=params.device,
    )
    # get_domains returns end-exclusive tuples that mostly tile the
    # chromosome (e.g. the last window may stop a few bins short). Repair
    # the tiling so downstream bin counts match the matrix.
    tiled = _tile_tad_windows(windows, n_bins)
    # The DI caller happily lumps low-signal regions (centromeric /
    # heterochromatic arms) into one huge TAD. Such TADs break the
    # assembly step: a 500-bin Gaussian-chain cloud has an Rg so large
    # that it can't sit next to a 30-bin neighbour without a pathological
    # boundary bond. Cap them.
    max_size = max(10, int(n_bins * params.tad_max_size_frac))
    return _split_oversized_tads(tiled, max_size)
# ----------------------------------------------------------------------
# FISH aggregation
# ----------------------------------------------------------------------
def _map_fish_bins_to_tads(
    fish_bin_ids: Sequence[Tuple[int, int]],
    tad_bin_indices: List[Tuple[int, int]],
    bin_df: pd.DataFrame,
) -> np.ndarray:
    """Return, for every FISH bin, the index of the TAD that contains it.

    A FISH bin ``(start, end)`` (bp) is placed by its midpoint. It is
    assigned to the Hi-C bin ``[start, end)`` of ``bin_df`` containing
    that midpoint, then to the first TAD window ``(s, e)`` (end-exclusive
    Hi-C bin indices) containing that Hi-C bin. ``-1`` marks FISH bins
    outside every Hi-C bin or outside every TAD.
    """
    hic_starts = bin_df["start"].to_numpy()
    hic_ends = bin_df["end"].to_numpy()
    n_hic = hic_starts.shape[0]

    pairs = np.asarray(
        [(int(fs), int(fe)) for fs, fe in fish_bin_ids], dtype=np.int64,
    ).reshape(-1, 2)
    mids = 0.5 * (pairs[:, 0] + pairs[:, 1])

    result = np.full(mids.shape[0], -1, dtype=np.int64)
    if n_hic == 0 or mids.size == 0:
        return result

    # First Hi-C bin whose end lies strictly past the midpoint; it contains
    # the midpoint iff its start is <= the midpoint.
    idx = np.searchsorted(hic_ends, mids, side="right")
    safe = np.minimum(idx, n_hic - 1)
    inside = (idx < n_hic) & (hic_starts[safe] <= mids)

    # Hi-C bin -> TAD lookup table. Filled back-to-front so that, if TAD
    # windows overlap, the first listed TAD wins.
    bin_to_tad = np.full(n_hic, -1, dtype=np.int64)
    for ti, (s, e) in reversed(list(enumerate(tad_bin_indices))):
        bin_to_tad[max(int(s), 0):max(int(e), 0)] = ti
    result[inside] = bin_to_tad[idx[inside]]
    return result


def _tad_fish_distances(
    fish_cd: Any,
    chrom: str,
    tad_bin_indices: List[Tuple[int, int]],
    bin_df: pd.DataFrame,
) -> np.ndarray:
    """Aggregate FISH-measured distances onto TAD pairs.

    For every pair of FISH bins, the population distance is the NaN-aware
    median over traces. ``F[i, j]`` (``i != j``) is the mean of those
    population distances over all FISH-bin pairs with one bin in TAD ``i``
    and the other in TAD ``j``. FISH bins are placed into TADs by their
    midpoint via ``bin_df``.

    Returns
    -------
    np.ndarray
        ``(n_tads, n_tads)`` float64 matrix: symmetric, zero diagonal,
        NaN where no finite FISH-bin pair contributes. If ``fish_cd`` is
        ``None`` the whole matrix (diagonal included) is NaN.
    """
    n_tads = len(tad_bin_indices)
    if fish_cd is None:
        # Hi-C-only run: no FISH targets, and no need for the FISH module.
        return np.full((n_tads, n_tads), np.nan, dtype=np.float64)

    import warnings

    from uchrom.fea.distance import _bin_coord_cube, _pairwise_distance_per_trace

    df = fish_cd.get_chrom(chrom).to_dataframe()
    cube, fish_bin_ids, _ = _bin_coord_cube(df, chrom=chrom)
    D = _pairwise_distance_per_trace(cube)  # (n_traces, n_fish_bins, n_fish_bins)
    with warnings.catch_warnings():
        # Bin pairs never co-observed give All-NaN slices; NaN is the
        # intended result there, so silence the RuntimeWarning.
        warnings.simplefilter("ignore", category=RuntimeWarning)
        pop = np.nanmedian(D, axis=0)

    fish_to_tad = _map_fish_bins_to_tads(fish_bin_ids, tad_bin_indices, bin_df)

    # All FISH-bin pairs i < j, gathered at once.
    iu, ju = np.triu_indices(pop.shape[0], k=1, m=pop.shape[1])
    ti = fish_to_tad[iu]
    tj = fish_to_tad[ju]
    vals = pop[iu, ju]
    # Same-TAD pairs only touch the diagonal, which is zeroed below.
    keep = (ti >= 0) & (tj >= 0) & (ti != tj) & np.isfinite(vals)
    ti, tj, vals = ti[keep], tj[keep], vals[keep]

    sums = np.zeros((n_tads, n_tads), dtype=np.float64)
    cnts = np.zeros((n_tads, n_tads), dtype=np.int64)
    # np.add.at accumulates repeated (row, col) indices correctly.
    np.add.at(sums, (ti, tj), vals)
    np.add.at(sums, (tj, ti), vals)
    np.add.at(cnts, (ti, tj), 1)
    np.add.at(cnts, (tj, ti), 1)

    with np.errstate(invalid="ignore", divide="ignore"):
        F = np.where(cnts > 0, sums / np.maximum(cnts, 1), np.nan)
    np.fill_diagonal(F, 0.0)
    return F


def _tad_rg_targets(
    fish_cd: Any,
    chrom: str,
    tad_bin_indices: List[Tuple[int, int]],
    bin_df: pd.DataFrame,
) -> List[Optional[float]]:
    """Per-TAD FISH-derived target :math:`\\hat{R}_g^2`.

    For each TAD, take the FISH bins mapped into it. For every trace,
    keep the spots whose three coordinates are all finite. If at least 3
    remain, compute Rg² as the mean squared distance to their centroid.
    The target is the median Rg² across qualifying traces.

    Returns ``None`` for a TAD with fewer than 3 FISH bins or without any
    qualifying trace, and for every TAD when ``fish_cd`` is ``None``.
    """
    if fish_cd is None:
        return [None] * len(tad_bin_indices)

    from uchrom.fea.distance import _bin_coord_cube

    df = fish_cd.get_chrom(chrom).to_dataframe()
    cube, fish_bin_ids, _ = _bin_coord_cube(df, chrom=chrom)
    fish_to_tad = _map_fish_bins_to_tads(fish_bin_ids, tad_bin_indices, bin_df)

    out: List[Optional[float]] = []
    for ti in range(len(tad_bin_indices)):
        members = np.flatnonzero(fish_to_tad == ti)
        if members.size < 3:
            out.append(None)
            continue
        rgs = []
        # cube is indexed (trace, fish_bin, xyz).
        for coords in cube[:, members, :]:
            # Require every coordinate finite: a NaN in y/z alone would
            # otherwise turn this trace's Rg² (and the median) into NaN.
            xyz = coords[np.isfinite(coords).all(axis=1)]
            if xyz.shape[0] < 3:
                continue
            centred = xyz - xyz.mean(axis=0)
            rgs.append(float((centred ** 2).sum(axis=1).mean()))
        out.append(float(np.median(rgs)) if rgs else None)
    return out
# ----------------------------------------------------------------------
# Main entry
# ----------------------------------------------------------------------
def _tad_windows_from_df(
    tads: pd.DataFrame,
    bin_df: pd.DataFrame,
) -> List[Tuple[int, int]]:
    """Convert a user TAD table (bp ``start``/``end``) to end-exclusive
    Hi-C bin windows.

    A TAD covers the Hi-C bins overlapping ``[start, end)``. Its first bin
    is the first whose ``end`` exceeds ``start``; its end-exclusive bound
    is the first bin whose ``start`` is ``>= end``. A TAD that overlaps no
    bin is widened to the single bin at its start position.

    Raises
    ------
    ValueError
        If ``tads`` is empty, or a TAD starts after the last Hi-C bin.
        Without this check, such a TAD would yield an out-of-range window.
    """
    if len(tads) == 0:
        raise ValueError("`tads` must contain at least one TAD")
    starts = bin_df["start"].to_numpy()
    ends = bin_df["end"].to_numpy()
    tad_start = tads["start"].to_numpy().astype(np.int64)
    tad_end = tads["end"].to_numpy().astype(np.int64)

    s = np.searchsorted(ends, tad_start + 1, side="left")
    e = np.searchsorted(starts, tad_end, side="left")
    e = np.maximum(e, s + 1)

    beyond = np.flatnonzero(s >= len(bin_df))
    if beyond.size:
        k = int(beyond[0])
        raise ValueError(
            f"TAD {k} (start={int(tad_start[k])}, end={int(tad_end[k])}) "
            f"lies beyond the last Hi-C bin"
        )
    return [(int(a), int(b)) for a, b in zip(s, e)]


def reconstruct_gem_fish(
    hic_path: str,
    chrom: str,
    resolution: Optional[int] = None,
    fish_cd: Any = None,
    tads: Optional[pd.DataFrame] = None,
    params: Optional[GEMFISHParams] = None,
    return_intermediate: bool = False,
) -> Any:
    """End-to-end GEM-FISH reconstruction for a single chromosome.

    Parameters
    ----------
    hic_path : str
        ``.cool`` or ``.mcool`` file.
    chrom : str
        Chromosome name as stored in the cooler.
    resolution : int, optional
        Required for ``.mcool``; ignored for ``.cool``.
    fish_cd : ChromData, optional
        FISH data, e.g. from ``ChromData.from_fofct(...)``. Supplies the
        C_3 (TAD-centre distances) and C_4 (per-TAD Rg²) targets. If
        ``None``, no FISH targets are available, so both FISH terms are
        dropped and the result is a Hi-C-only reconstruction.
    tads : DataFrame, optional
        Override the TAD partition. Must have columns ``start`` and
        ``end`` (bp) and be sorted. If omitted, the Dixon DI caller
        (:func:`uchrom.strc.tad.get_domains`) is run on the Hi-C matrix.
    params : GEMFISHParams, optional
        Runtime parameters; defaults to ``GEMFISHParams()``.
    return_intermediate : bool
        If True, also return a dict with every intermediate artefact
        (TAD windows, Stage-1 centres, per-TAD coords, FISH targets,
        per-stage info).

    Returns
    -------
    ChromData or tuple
        Reconstructed 3-D model with bin-resolution coordinates. Each
        Hi-C bin becomes a spot, and ``trace_id = 0`` for the whole chain
        (single consensus model). Per-stage summary values are stored in
        ``cd.uns['gem_fish']``. If ``return_intermediate=True``, returns
        ``(cd, dict)`` with the full artefact bundle.

    Raises
    ------
    ValueError
        If a user-supplied ``tads`` table is empty or contains a TAD that
        starts after the last Hi-C bin of ``chrom``.
    """
    from uchrom import ChromData

    params = params or GEMFISHParams()

    if params.verbose:
        print(f"[gem_fish] loading Hi-C for {chrom} …")
    hic_mat, bin_df = load_contact_matrix(
        hic_path, chrom, resolution=resolution, normalize=True,
    )
    if params.verbose:
        print(f"[gem_fish] Hi-C matrix: {hic_mat.shape} "
              f"nnz={int((hic_mat > 0).sum())}")

    # --- TAD partition (end-exclusive Hi-C bin windows)
    if tads is None:
        if params.verbose:
            print("[gem_fish] calling TADs via Dixon DI …")
        tad_windows = _auto_tads(hic_mat, params)
    else:
        tad_windows = _tad_windows_from_df(tads, bin_df)
    if params.verbose:
        print(f"[gem_fish] {len(tad_windows)} TADs")

    # --- Stage 1: TAD-level
    # Express each window back in bp so the Hi-C matrix can be aggregated
    # into a TAD x TAD contact matrix.
    tad_bp = pd.DataFrame([
        {"start": int(bin_df["start"].iloc[s]),
         "end": int(bin_df["end"].iloc[e - 1])}
        for (s, e) in tad_windows
    ])
    inter_tad_contacts, _ = aggregate_over_tads(hic_mat, bin_df, tad_bp)
    tad_fish_F = _tad_fish_distances(fish_cd, chrom, tad_windows, bin_df)

    if params.verbose:
        print("[gem_fish] Stage 1 — TAD-level reconstruction …")
    s1_params = Stage1Params(
        lambda_E=params.lambda_E_tad,
        lambda_F=params.lambda_F,
        n_iter=params.stage1_iter,
        lr=params.stage1_lr,
        n_ensemble=params.stage1_ensemble,
        optimiser=params.stage1_optimiser,
        device=params.device,
        verbose=params.verbose,
    )
    tad_centres, s1_info = reconstruct_tad_level(
        inter_tad_contacts, tad_fish_F, params=s1_params,
    )

    # --- Stage 2: intra-TAD
    rg_targets = _tad_rg_targets(fish_cd, chrom, tad_windows, bin_df)
    if params.verbose:
        print("[gem_fish] Stage 2 — intra-TAD reconstruction …")
    s2_params = Stage2Params(
        lambda_E=params.lambda_E_intra,
        lambda_R=params.lambda_R,
        n_iter=params.stage2_iter,
        lr=params.stage2_lr,
        n_ensemble=params.stage2_ensemble,
        device=params.device,
        verbose=params.verbose,
    )
    intra_coords, s2_info = reconstruct_intra_tad(
        hic_mat, tad_windows,
        target_rg_sq_per_tad=rg_targets,
        params=s2_params,
    )

    # --- Stage 3: assembly
    if params.verbose:
        print("[gem_fish] Stage 3 — assembly …")
    # Genomic separation between TAD midpoints (bp) feeds the fallback of
    # the contact -> distance conversion.
    mid_bp = (bin_df["start"].to_numpy() + bin_df["end"].to_numpy()) / 2
    tad_mid_bp = np.array([
        0.5 * (mid_bp[s] + mid_bp[e - 1]) for (s, e) in tad_windows
    ])
    genomic_dist = np.abs(tad_mid_bp[:, None] - tad_mid_bp[None, :])
    inter_tad_gap = contact_to_distance(
        inter_tad_contacts, genomic_distances=genomic_dist,
        alpha=params.hic_alpha,
        fallback_c=params.hic_fallback_c,
        fallback_beta=params.hic_fallback_beta,
    )
    final_coords = assemble(
        tad_centres, intra_coords, inter_tad_distances=inter_tad_gap,
    )
    if params.verbose:
        print(f"[gem_fish] done. Output: {final_coords.shape}")

    # --- Wrap into ChromData: one spot per Hi-C bin, single trace.
    spots = pd.DataFrame({
        "chrom": bin_df["chrom"].values,
        "start": bin_df["start"].astype(int).values,
        "end": bin_df["end"].astype(int).values,
        "trace_id": 0,
    })
    uns = {
        "gem_fish": {
            "n_tads": len(tad_windows),
            "stage1_final_loss": float(s1_info["final_loss"].min()),
            "stage1_best_ensemble": int(s1_info["best_ensemble_idx"]),
        },
        "xyz_unit": "reduced",  # arbitrary reconstruction units
        "genome_assembly": None,
    }
    cd = ChromData(
        coords=final_coords,
        spots=spots,
        uns=uns,
    )
    if return_intermediate:
        return cd, {
            "tad_windows": tad_windows,
            "tad_centres": tad_centres,
            "intra_coords": intra_coords,
            "inter_tad_contacts": inter_tad_contacts,
            "inter_tad_gap": inter_tad_gap,
            "tad_fish_F": tad_fish_F,
            "rg_targets": rg_targets,
            "stage1_info": s1_info,
            "stage2_info": s2_info,
        }
    return cd
# Public API of this module; everything else is an implementation detail.
__all__ = [
    "GEMFISHParams",
    "reconstruct_gem_fish",
]