GEM-FISH: joint Hi-C + FISH 3-D reconstruction
uchrom.recon.fish.reconstruct_gem_fish is an independent
PyTorch reimplementation of Abbas et al. 2019 (Nature Communications,
doi:10.1038/s41467-019-10005-6). It solves a three-stage optimisation
that fuses bulk Hi-C contacts with FISH-measured pairwise distances
into a single 3-D chromosome model.
Pipeline (per chromosome):
TAD partition: call the Dixon directionality-index (DI) caller on the Hi-C matrix.
Stage 1, TAD level: gradient descent on C_g = C_1 + λ_E·C_2 + λ_F·C_3 (Hi-C KL + polymer prior + FISH inter-TAD distance).
Stage 2, intra-TAD: per-TAD gradient descent on C_t = C_1 + λ_E·C_2 + λ_R·C_4 (intra-TAD Hi-C KL + polymer prior + FISH Rg target).
Stage 3, assembly: Kabsch-style rotation of each intra-TAD cloud around its Stage-1 centre.
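To make the three-term structure of C_g concrete, here is a toy NumPy version. The term forms are assumptions chosen for illustration, not the exact definitions used by Abbas et al. or by uchrom: C_1 is taken as a KL divergence between row-normalised observed contacts and row-normalised model contacts 1/(1 + d), C_2 as a bond-length polymer prior, and C_3 as a squared error against FISH distances. toy_gem_fish_loss is a hypothetical helper, not part of uchrom.

```python
import numpy as np


def _row_normalise(m):
    """Scale each row to sum to 1; all-zero rows stay zero."""
    s = m.sum(axis=1, keepdims=True)
    return np.divide(m, s, out=np.zeros_like(m, dtype=float), where=s > 0)


def toy_gem_fish_loss(X, contacts, fish, lam_E, lam_F, eps=1e-12):
    """Hypothetical C_g = C_1 + lam_E*C_2 + lam_F*C_3 for coordinates X of shape (n, 3).

    Assumed term forms (illustrative only):
      C_1: KL(P || Q), P = row-normalised Hi-C, Q = row-normalised 1/(1+d)
      C_2: polymer prior, squared deviation of consecutive bond lengths from 1
      C_3: squared error between model distances and FISH distances
    """
    D = np.linalg.norm(X[:, None, :] - X[None, :, :], axis=-1)
    model = 1.0 / (1.0 + D)
    np.fill_diagonal(model, 0.0)                      # ignore self-contacts
    P = _row_normalise(np.asarray(contacts, dtype=float))
    Q = _row_normalise(model)
    mask = P > 0
    c1 = np.sum(P[mask] * np.log(P[mask] / (Q[mask] + eps)))
    bonds = np.linalg.norm(np.diff(X, axis=0), axis=1)
    c2 = np.sum((bonds - 1.0) ** 2)
    iu = np.triu_indices(len(X), 1)
    c3 = np.sum((D[iu] - fish[iu]) ** 2)
    return c1 + lam_E * c2 + lam_F * c3, (c1, c2, c3)


# Toy usage: a straight chain with unit bonds, scored against its own
# contacts and distances, gives a loss of (numerically) zero.
X_demo = np.column_stack([np.arange(5.0), np.zeros(5), np.zeros(5)])
D_demo = np.linalg.norm(X_demo[:, None, :] - X_demo[None, :, :], axis=-1)
C_demo = 1.0 / (1.0 + D_demo)
np.fill_diagonal(C_demo, 0.0)
total, terms = toy_gem_fish_loss(X_demo, C_demo, D_demo, lam_E=0.05, lam_F=0.5)
print(total, terms)
```

Because both contact matrices are row-normalised, multiplying the Hi-C counts by a constant leaves C_1 unchanged; that is one reason λ values tuned on raw counts do not carry over to a row-normalised loss.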
The notebook has two parts. The first runs end to end on a fully synthetic chain, so you can see the pipeline's mechanics against a known ground truth. The second downloads the Bintu et al. 2018 (Science) per-trace imaging data for IMR90 chr21:18.6–20.6 Mb, a follow-up to the Wang 2016 FISH data used in the original paper, pairs it with the Rao 2014 IMR90 Hi-C map for chr21, and runs GEM-FISH on the combination.
import numpy as np
import matplotlib.pyplot as plt
from scipy.spatial.distance import cdist
from uchrom.recon.fish._tad_level import (
    Stage1Params, reconstruct_tad_level,
)
from uchrom.recon.fish._intra_tad import (
    Stage2Params, reconstruct_intra_tad,
)
from uchrom.recon.fish._assembly import assemble
from uchrom.recon.fish._hic import contact_to_distance
Part 1: Synthetic end-to-end
Generate a known polymer chain organised into 5 tight TADs, derive Hi-C contacts from it with f = 1/(1+d), and derive noisy FISH inputs (TAD-centre distances and per-TAD Rg² targets) from the same truth. Then run each of the three stages and compare against the ground truth at every step.
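The synthetic contact model f = 1/(1 + d) is strictly decreasing in d, so the fabricated Hi-C carries exactly the same information as the true distances and can be inverted as d = 1/f − 1. The self-pair (d = 0) would contribute the maximum value f = 1, which is why the cell below zeroes the diagonal. A quick round-trip check with two small helpers (hypothetical names, plain NumPy):

```python
import numpy as np


def contacts_from_distance(d):
    """Part 1's synthetic contact model: f = 1 / (1 + d), so f is in (0, 1] for d >= 0."""
    return 1.0 / (1.0 + d)


def distance_from_contacts(f):
    """Exact inverse of that model: d = 1/f - 1 (defined for f > 0 only)."""
    return 1.0 / f - 1.0


d_demo = np.array([0.0, 0.5, 1.0, 4.0, 9.0])
f_demo = contacts_from_distance(d_demo)   # 1.0, 0.667, 0.5, 0.2, 0.1
d_back = distance_from_contacts(f_demo)   # recovers d_demo up to rounding
print(f_demo, d_back)
```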
rng = np.random.default_rng(0)
n_tads, bins_per_tad = 5, 20

# Each TAD is a compact random walk recentred on a TAD centre;
# successive centres step ~5 units, roughly along +x.
truth = []
centre = np.zeros(3)
for t in range(n_tads):
    walk = np.cumsum(rng.normal(0, 0.3, (bins_per_tad, 3)), axis=0)
    truth.append(walk - walk.mean(0) + centre)
    centre = centre + 5.0 * (np.array([1., 0, 0]) + 0.3 * rng.normal(0, 1, 3))
truth = np.vstack(truth)
N = truth.shape[0]
tad_bins = [(t*bins_per_tad, (t+1)*bins_per_tad) for t in range(n_tads)]
tad_centres_true = np.array([truth[s:e].mean(0) for (s, e) in tad_bins])

# Bin-level Hi-C from f = 1/(1+d), self-contacts removed
D_true = cdist(truth, truth)
contacts = 1.0 / (1.0 + D_true); np.fill_diagonal(contacts, 0)

# Inter-TAD contacts: sum of the bin-level block for each TAD pair
inter_contacts = np.zeros((n_tads, n_tads))
for i, (si, ei) in enumerate(tad_bins):
    for j, (sj, ej) in enumerate(tad_bins):
        inter_contacts[i, j] = contacts[si:ei, sj:ej].sum()

# FISH centre distances + noise, mirrored so the matrix stays symmetric
F = cdist(tad_centres_true, tad_centres_true)
noise = np.triu(rng.normal(0, 0.1, F.shape), 1)
F = F + noise + noise.T

# Rg² targets per TAD (with 5 % multiplicative noise)
rg_targets = []
for (s, e) in tad_bins:
    sub = truth[s:e]
    rg2 = ((sub - sub.mean(0)) ** 2).sum(1).mean()
    rg_targets.append(rg2 * (1 + 0.05 * rng.standard_normal()))
print(f'Ground truth: {N} bins across {n_tads} TADs')
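The Rg² targets above are computed from coordinates, but the same quantity follows from pairwise distances alone through the standard identity Rg² = (1/n) Σ_i |x_i − x̄|² = Σ_i Σ_j d_ij² / (2n²). That matters for real FISH, where pairwise distances are what is measured; because the identity is linear in d², applying it per trace and averaging gives the population-mean Rg². A minimal check (helper names are hypothetical):

```python
import numpy as np
from scipy.spatial.distance import cdist


def rg2_from_coords(x):
    """Mean squared distance of the points to their centroid."""
    return ((x - x.mean(0)) ** 2).sum(1).mean()


def rg2_from_distances(d):
    """Same quantity from a full pairwise-distance matrix: sum_ij d_ij^2 / (2 n^2)."""
    n = d.shape[0]
    return (d ** 2).sum() / (2.0 * n * n)


demo_rng = np.random.default_rng(1)
pts_demo = np.cumsum(demo_rng.normal(0, 0.3, (20, 3)), axis=0)
print(rg2_from_coords(pts_demo), rg2_from_distances(cdist(pts_demo, pts_demo)))  # equal up to rounding
```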
# Stage 1: TAD level
tad_centres_recon, s1_info = reconstruct_tad_level(
    inter_contacts, F,
    params=Stage1Params(
        lambda_E=0.05, lambda_F=0.5,
        n_iter=500, n_ensemble=20, device='cpu',
    ),
)
iu = np.triu_indices(n_tads, 1)
corr_s1 = np.corrcoef(cdist(tad_centres_recon, tad_centres_recon)[iu],
                      cdist(tad_centres_true, tad_centres_true)[iu])[0, 1]
print(f'Stage 1 centre-distance correlation: {corr_s1:.3f}')
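The Stage-1 score compares upper-triangle distances rather than coordinates. Pairwise distances are unchanged by rotation, reflection and translation, and Pearson correlation is unchanged by a positive rescaling, so the metric needs neither an alignment step nor shared units. The flip side is that a mirror image scores exactly as well, so a correlation near 1 says nothing about chirality. A small demonstration (distance_corr is a hypothetical helper):

```python
import numpy as np
from scipy.spatial.distance import cdist


def distance_corr(a, b):
    """Pearson correlation between the upper-triangle pairwise distances of a and b."""
    iu_ab = np.triu_indices(len(a), 1)
    return np.corrcoef(cdist(a, a)[iu_ab], cdist(b, b)[iu_ab])[0, 1]


demo_rng = np.random.default_rng(0)
pts = demo_rng.normal(size=(5, 3))
q_demo, _ = np.linalg.qr(demo_rng.normal(size=(3, 3)))        # random orthogonal 3x3 matrix
moved = 3.0 * pts @ q_demo.T + np.array([10.0, -2.0, 4.0])    # rotate/reflect, scale, shift
mirrored = pts * np.array([-1.0, 1.0, 1.0])                   # explicit mirror image
print(distance_corr(pts, moved), distance_corr(pts, mirrored))  # both ≈ 1.0
```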
# Stage 2: intra-TAD
intra_coords, s2_info = reconstruct_intra_tad(
    contacts, tad_bins,
    target_rg_sq_per_tad=rg_targets,
    params=Stage2Params(
        lambda_E=0.05, lambda_R=0.01,
        n_iter=300, n_ensemble=6, device='cpu',
    ),
)
per_tad_corr = []
for i, (s, e) in enumerate(tad_bins):
    Di = cdist(intra_coords[i], intra_coords[i])
    Dt = cdist(truth[s:e], truth[s:e])
    iu_t = np.triu_indices(e - s, 1)
    per_tad_corr.append(np.corrcoef(Di[iu_t], Dt[iu_t])[0, 1])
print(f'Stage 2 mean intra-TAD distance corr: {np.mean(per_tad_corr):.3f}')
# Stage 3: assembly
# TAD midpoints in bin units (s and e are already bin indices)
tad_mid = np.array([(s + e - 1) / 2 for (s, e) in tad_bins])
genomic = np.abs(tad_mid[:, None] - tad_mid[None, :]).astype(np.float64)
gap = contact_to_distance(
    inter_contacts, genomic_distances=genomic,
    alpha=0.25, fallback_c=1.0, fallback_beta=0.5,
)
final = assemble(tad_centres_recon, intra_coords, inter_tad_distances=gap)
iu_n = np.triu_indices(N, 1)
corr_final = np.corrcoef(cdist(final, final)[iu_n], D_true[iu_n])[0, 1]
print(f'Final chain distance correlation: {corr_final:.3f}')
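Distance correlation summarises global shape; a complementary check is the Procrustes disparity, which superimposes the two chains bin by bin (after centring, unit-norm scaling and an optimal orthogonal transform) and sums the squared residuals. SciPy ships this as scipy.spatial.procrustes; in this notebook it could be applied as procrustes(truth, final)[2]. Like distance correlation, it allows reflections, so it is also blind to chirality. A self-contained sketch on a synthetic chain:

```python
import numpy as np
from scipy.spatial import procrustes

demo_rng = np.random.default_rng(0)
ref_chain = np.cumsum(demo_rng.normal(0, 0.3, (100, 3)), axis=0)
q_demo, _ = np.linalg.qr(demo_rng.normal(size=(3, 3)))          # random orthogonal matrix
ref_moved = 2.5 * ref_chain @ q_demo.T + 1.0                    # same shape, new pose and scale
ref_noisy = ref_moved + demo_rng.normal(0, 0.2, ref_moved.shape)  # plus coordinate noise

_, _, d_same = procrustes(ref_chain, ref_moved)
_, _, d_noisy = procrustes(ref_chain, ref_noisy)
print(d_same, d_noisy)  # d_same ≈ 0; d_noisy > 0 grows with the noise
```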
# Visualise truth vs reconstruction using uchrom.pl.plot_structure_3d.
# Use PyVista subplots so both tubes render side by side in one figure.
import pyvista as pv
from uchrom.pl import plot_structure_3d

# TAD-membership scalar so each TAD is coloured distinctly
tad_scalar = np.concatenate([
    np.full(e - s, t, dtype=np.float64) for t, (s, e) in enumerate(tad_bins)
])
plotter = pv.Plotter(shape=(1, 2), notebook=True, window_size=(1100, 500))
plotter.subplot(0, 0)
plotter.add_text('Ground truth', font_size=10)
plot_structure_3d(
    truth, colour=tad_scalar, cmap='viridis',
    plotter=plotter, show=False,
)
plotter.subplot(0, 1)
plotter.add_text(
    f'GEM-FISH reconstruction (chain corr = {corr_final:.2f})', font_size=10,
)
plot_structure_3d(
    final, colour=tad_scalar, cmap='viridis',
    plotter=plotter, show=False,
)
plotter.show(jupyter_backend='static')
# Distance-matrix comparison (truth and reconstruction share one colour scale)
D_final = cdist(final, final)
fig, axes = plt.subplots(1, 3, figsize=(13, 4))
vmax = np.percentile(D_true, 98)
for ax, mat, title, vm in zip(
    axes,
    [D_true, D_final, np.abs(D_true - D_final)],
    ['Truth', 'Reconstruction', '|Truth − Recon|'],
    [vmax, vmax, vmax / 3],
):
    im = ax.imshow(mat, cmap='viridis_r', origin='lower', vmax=vm)
    ax.set_title(title); ax.set_xlabel('Bin'); ax.set_ylabel('Bin')
    plt.colorbar(im, ax=ax, fraction=0.046)
plt.tight_layout()
plt.show()
Part 2: Paper-faithful IMR90 chr21 reconstruction
Load the real Bintu 2018 FISH data (auto-downloaded) and the real
Rao 2014 IMR90 Hi-C (small 30-kb .cool shipped under example-data/,
derived from 4DN accession 4DNFI4QQPDMR),
then run the one-shot reconstruct_gem_fish entry point.
from pathlib import Path
import urllib.request
from uchrom.io import read_bintu_tracing
DATA_URL = (
'https://raw.githubusercontent.com/BogdanBintu/ChromatinImaging/'
'master/Data/IMR90_chr21-18-20Mb.csv'
)
def _repo_root():
    """Walk up from the working directory to the folder holding pyproject.toml."""
    for p in [Path.cwd(), *Path.cwd().parents]:
        if (p / 'pyproject.toml').exists():
            return p
    return Path.cwd()

# Auto-download Bintu 2018 (~2 MB) if absent
root = _repo_root()
bintu_csv = root / 'example-data' / 'IMR90_chr21-18-20Mb.csv'
if not bintu_csv.exists():
    bintu_csv.parent.mkdir(parents=True, exist_ok=True)
    print(f'[data] downloading Bintu 2018 (~2 MB) → {bintu_csv}')
    urllib.request.urlretrieve(DATA_URL, bintu_csv)
# Load into ChromData via the Bintu reader
fish_cd = read_bintu_tracing(
    str(bintu_csv), chrom='chr21',
    start_bp=18_627_714,  # hg38, from the Bintu repo README
    resolution_bp=30_000,
)
print(fish_cd)
n_traces = (fish_cd.traces.shape[0] if fish_cd.traces is not None
            else fish_cd.spots['trace_id'].nunique())
print(f'  n_traces: {n_traces}')
# Hi-C .cool is shipped in the repo (0.3 MB, chr21 only, 30 kb)
hic_cool = root / 'example-data' / 'IMR90_chr21_30kb.cool'
print(f'Hi-C: {hic_cool} ({hic_cool.stat().st_size / 1e6:.2f} MB)')
from uchrom.recon.fish import reconstruct_gem_fish, GEMFISHParams
cd, intermediate = reconstruct_gem_fish(
    hic_path=str(hic_cool),
    chrom='chr21',
    resolution=None,            # single-resolution .cool
    fish_cd=fish_cd,
    params=GEMFISHParams(
        lambda_E_tad=0.5, lambda_F=0.5,
        lambda_E_intra=0.05, lambda_R=0.01,
        stage1_iter=500, stage1_ensemble=20,
        stage2_iter=300, stage2_ensemble=6,
        tad_di_window=50, tad_di_min_size_frac=0.01,
        tad_max_size_frac=0.05,  # cap huge DI-called TADs (p-arm)
        device='auto', verbose=False,
    ),
    return_intermediate=True,
)
print(f'Reconstructed: {cd}')
print(f' n_tads called: {cd.uns["gem_fish"]["n_tads"]}')
print(f' final Stage-1 loss: {cd.uns["gem_fish"]["stage1_final_loss"]:.2f}')
# Evaluate: how well does the reconstruction match the FISH
# population-median pairwise-distance matrix?
from uchrom.fea.distance import _bin_coord_cube, _pairwise_distance_per_trace
df_fish = fish_cd.to_dataframe()
cube, fish_bin_ids, _ = _bin_coord_cube(df_fish, chrom='chr21')
D_fish = np.nanmedian(_pairwise_distance_per_trace(cube), axis=0)
# Map each FISH bin to its corresponding Hi-C bin
bin_starts = cd.spots['start'].to_numpy()
fish_starts = np.array([s for s, _ in fish_bin_ids])
idx = np.array([int(np.argmin(np.abs(bin_starts - fs))) for fs in fish_starts])
D_recon = cdist(cd.coords[idx], cd.coords[idx])
iu_r = np.triu_indices(D_fish.shape[0], 1)
finite = np.isfinite(D_fish[iu_r]) & np.isfinite(D_recon[iu_r])
corr_real = np.corrcoef(D_fish[iu_r][finite], D_recon[iu_r][finite])[0, 1]
print(f'Reconstruction vs FISH pop-median correlation: {corr_real:.3f}')
print(f' evaluated over {finite.sum()} bin pairs in the Bintu imaged region')
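The FISH grid starts at 18,627,714 bp, which is not a multiple of 30 kb, so FISH segments and Hi-C bins are offset and the cell above matches each FISH segment to the Hi-C bin with the nearest start. The same mapping can be written with broadcasting instead of a Python loop. The sketch below assumes, for illustration, that the Hi-C bins start at 0 bp in 30-kb steps (1,557 bins for chr21) and that the 65 FISH segments are contiguous 30-kb tiles from start_bp; nearest_bin is a hypothetical helper.

```python
import numpy as np


def nearest_bin(bin_starts, query_starts):
    """Index of the closest entry of bin_starts for every query start (broadcast argmin)."""
    bin_starts = np.asarray(bin_starts, dtype=np.int64)
    query_starts = np.asarray(query_starts, dtype=np.int64)
    return np.abs(query_starts[:, None] - bin_starts[None, :]).argmin(axis=1)


res = 30_000
hic_starts_demo = np.arange(1557, dtype=np.int64) * res       # assumed 30-kb chr21 grid
fish_starts_demo = 18_627_714 + res * np.arange(65)           # assumed contiguous FISH tiles
idx_demo = nearest_bin(hic_starts_demo, fish_starts_demo)
print(idx_demo[:3], (hic_starts_demo[idx_demo] - fish_starts_demo)[:3])
# prints [621 622 623] [2286 2286 2286]
```

Every FISH segment then lands about 2.3 kb from its Hi-C bin start, well inside one 30-kb bin.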
# Visualise. Scale the recon distance matrix to nm with the ratio of mean
# distances to FISH, so both heat maps share units and one colour scale.
scale_est = np.nanmean(D_fish[iu_r][finite]) / np.mean(D_recon[iu_r][finite])
vmax_nm = np.nanpercentile(D_fish, 98)
fig, axes = plt.subplots(1, 2, figsize=(11, 4.5))
im0 = axes[0].imshow(D_fish, cmap='viridis_r', origin='lower', vmax=vmax_nm)
axes[0].set_title('FISH population median (nm)')
plt.colorbar(im0, ax=axes[0])
im1 = axes[1].imshow(D_recon * scale_est, cmap='viridis_r', origin='lower',
                     vmax=vmax_nm)
axes[1].set_title(f'GEM-FISH reconstruction, rescaled (corr={corr_real:.2f})')
plt.colorbar(im1, ax=axes[1])
for ax in axes:
    ax.set_xlabel('Bin'); ax.set_ylabel('Bin')
plt.tight_layout()
plt.show()
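The ratio of mean distances is one reasonable scale factor for display. A common alternative is the least-squares scale s = ⟨D_fish, D_recon⟩ / ⟨D_recon, D_recon⟩, which minimises ‖D_fish − s·D_recon‖² and weights long distances more heavily. Either choice only affects the plot: the Pearson correlation reported above is invariant to rescaling. Both helpers below are hypothetical and skip non-finite FISH entries:

```python
import numpy as np


def ls_scale(target, source):
    """Scale s minimising ||target - s * source||^2 over finite entries."""
    ok = np.isfinite(target) & np.isfinite(source)
    t, s = target[ok], source[ok]
    return float(np.dot(t, s) / np.dot(s, s))


def ratio_of_means_scale(target, source):
    """The notebook's choice: mean(target) / mean(source) over finite entries."""
    ok = np.isfinite(target) & np.isfinite(source)
    return float(target[ok].mean() / source[ok].mean())


recon_demo = np.array([1.0, 2.0, 3.0, 4.0])
fish_nm_demo = 250.0 * recon_demo          # FISH exactly proportional to the recon
fish_nm_demo[1] = np.nan                   # one missing FISH pair
print(ls_scale(fish_nm_demo, recon_demo), ratio_of_means_scale(fish_nm_demo, recon_demo))
```

With exactly proportional data both return 250; they diverge only when the residuals are not proportional.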
3-D rendering of the reconstructed chain
uchrom.pl.plot_structure_3d is a PyVista-based renderer
(same stack the uchrom.browser uses, without the Qt window) that
draws the reconstructed chain as a smooth tube coloured along the
genomic direction.
from uchrom.pl import plot_structure_3d
plot_structure_3d(
    cd, chrom='chr21',
    colour='bin', cmap='plasma',
    window_size=(1000, 800),
    notebook=True, jupyter_backend='static',
)
Hi-C contact map vs reconstructed distance matrix
Sanity check: the reconstructed 3-D structure should reproduce the input Hi-C contact pattern, i.e. regions with high contact frequency should be close in 3-D. We compare side by side:
Input Hi-C contact map, log10(1 + f_ij), where f_ij are the balanced contact counts from the Rao 2014 IMR90 .cool.
Reconstructed pairwise distance matrix: the Euclidean distance between every pair of reconstructed bins.
If the reconstruction is self-consistent with the input, the two maps should show inverse patterns: the high-contact diagonal band, TAD blocks and loops in the Hi-C map should line up with short-distance blocks in the distance matrix. The second row of the figure repeats the comparison after observed/expected (O/E) normalisation, which divides out the decay with genomic separation so that structure beyond that trend becomes visible.
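We also compute the Spearman correlation between contact frequency and 1 / distance over all off-diagonal pairs with non-zero contact. Spearman's ρ is rank-based, so applying log10(1 + ·) to the contacts would not change it. The paper reports ~0.6–0.8 on chr21.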
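Spearman's ρ depends only on ranks, so any strictly increasing transform of either input leaves it unchanged: log10(1 + contact) ranks exactly like the raw contact, and −d ranks exactly like 1/d. A quick numerical check on synthetic data (all names are illustrative):

```python
import numpy as np
from scipy.stats import spearmanr

demo_rng = np.random.default_rng(0)
dist_demo = demo_rng.uniform(0.5, 3.0, 200)                      # synthetic distances
contact_demo = demo_rng.poisson(10.0 / dist_demo).astype(float)  # more contacts when close

rho_a, _ = spearmanr(contact_demo, 1.0 / dist_demo)
rho_b, _ = spearmanr(np.log10(1.0 + contact_demo), 1.0 / dist_demo)  # monotone transform
rho_c, _ = spearmanr(contact_demo, -dist_demo)                       # -d ranks like 1/d
print(rho_a, rho_b, rho_c)  # identical: only the ranks enter
```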
from uchrom.recon.fish._hic import load_contact_matrix
from scipy.stats import spearmanr
from scipy.spatial.distance import cdist
hic_mat, hic_bin_df = load_contact_matrix(
str(hic_cool), chrom='chr21', resolution=None, normalize=True,
)
hic_mat = np.nan_to_num(hic_mat, nan=0.0)
D_recon_full = cdist(cd.coords, cd.coords)
n_bins = hic_mat.shape[0]
# Observed / expected: divide each off-diagonal element by the mean at
# its genomic separation. Reveals structural features independent of
# the distance-decay trend.
def oe_normalize(mat):
    n = mat.shape[0]
    out = np.zeros_like(mat, dtype=np.float64)
    for k in range(1, n):
        # Expected value at separation k: mean of the positive entries on diagonal k
        diag_vals = np.diagonal(mat, offset=k)
        pos = diag_vals[diag_vals > 0]
        if pos.size == 0:
            continue
        mu = pos.mean()
        for i in range(n - k):
            out[i, i + k] = out[i + k, i] = (
                mat[i, i + k] / mu if mat[i, i + k] > 0 else 0.0
            )
    return out

hic_oe = oe_normalize(hic_mat)
dist_oe = oe_normalize(D_recon_full)
# For the distance matrix, O/E > 1 means farther than expected; taking the
# reciprocal makes it read like contact frequency (> 1 = closer than expected).
dist_inv_oe = np.divide(1.0, dist_oe, out=np.zeros_like(dist_oe), where=dist_oe > 0)
fig, axes = plt.subplots(2, 2, figsize=(10, 9))
im00 = axes[0, 0].imshow(np.log10(hic_mat + 1), cmap='Reds', origin='lower')
axes[0, 0].set_title('Input Hi-C (log10 balanced)')
plt.colorbar(im00, ax=axes[0, 0])
im01 = axes[0, 1].imshow(D_recon_full, cmap='viridis_r', origin='lower',
                         vmax=np.percentile(D_recon_full, 98))
axes[0, 1].set_title('Reconstructed distance (recon units)')
plt.colorbar(im01, ax=axes[0, 1])
im10 = axes[1, 0].imshow(np.log10(hic_oe + 1e-2), cmap='RdBu_r', origin='lower',
                         vmin=-1, vmax=1)
axes[1, 0].set_title('Hi-C O/E (log10)')
plt.colorbar(im10, ax=axes[1, 0])
im11 = axes[1, 1].imshow(np.log10(dist_inv_oe + 1e-2), cmap='RdBu_r',
                         origin='lower', vmin=-1, vmax=1)
axes[1, 1].set_title('Reconstructed 1/(distance O/E) (log10)')
plt.colorbar(im11, ax=axes[1, 1])
for ax in axes.flat:
    ax.set_xlabel('Bin'); ax.set_ylabel('Bin')
plt.tight_layout()
plt.show()
# Correlations
iu = np.triu_indices(n_bins, k=1)
ok = (hic_mat[iu] > 0) & (D_recon_full[iu] > 0)
rho_raw, _ = spearmanr(hic_mat[iu][ok], 1.0 / D_recon_full[iu][ok])
ok_oe = (hic_oe[iu] > 0) & (dist_oe[iu] > 0)
rho_oe, _ = spearmanr(hic_oe[iu][ok_oe], 1.0 / dist_oe[iu][ok_oe])
print(f'Spearman(Hi-C raw contact, 1 / recon distance): {rho_raw:.3f}')
print(f'Spearman(Hi-C O/E, 1 / recon distance O/E): {rho_oe:.3f}')
print(' (raw captures mostly the distance-decay trend;')
print(' O/E-corrected shows the structural agreement beyond it)')
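The oe_normalize function above walks each diagonal with an inner Python loop over i, which is slow on the 1,557-bin chr21 matrix (about 1.2 million upper-triangle entries). The sketch below keeps the same convention (expected value = mean of the strictly positive entries on each diagonal, zeros stay zero, main diagonal left at zero, upper triangle mirrored) but assigns each diagonal with fancy indexing. oe_normalize_fast is a hypothetical name.

```python
import numpy as np


def oe_normalize_fast(mat):
    """Observed/expected per diagonal; one vectorised assignment per offset k."""
    n = mat.shape[0]
    out = np.zeros_like(mat, dtype=np.float64)
    for k in range(1, n):
        diag = np.diagonal(mat, offset=k)          # entries mat[i, i + k]
        pos = diag[diag > 0]
        if pos.size == 0:
            continue                               # nothing to normalise at this separation
        vals = np.where(diag > 0, diag / pos.mean(), 0.0)
        idx = np.arange(n - k)
        out[idx, idx + k] = vals                   # upper triangle
        out[idx + k, idx] = vals                   # mirrored lower triangle
    return out


demo_mat = np.array([[0.0, 4.0, 1.0],
                     [4.0, 0.0, 2.0],
                     [1.0, 2.0, 0.0]])
print(oe_normalize_fast(demo_mat))
```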
Scaling up to other chromosomes
IMR90_chr21_30kb.cool only contains chr21. For other chromosomes (or a different resolution), download the full Rao 2014 IMR90 .mcool from 4DN (~810 MB, public S3):
https://4dn-open-data-public.s3.amazonaws.com/fourfront-webprod/wfoutput/1c856462-1e76-4850-bfa6-defcf1524d42/4DNFI4QQPDMR.mcool
Then:
cd = reconstruct_gem_fish(
    hic_path='4DNFI4QQPDMR.mcool',
    chrom='chr20',
    resolution=30_000,   # required for .mcool
    fish_cd=fish_cd,     # swap in FISH data for that chromosome
    params=GEMFISHParams(device='auto'),
)
cd.write('imr90_chr20_gemfish.h5cd')
Notes and caveats
λ weights: the paper's values (λ_E = 5 × 10¹², λ_F = 1 × 10⁻⁸) are calibrated for raw Hi-C counts without row normalisation. The PyTorch C_1 here row-normalises inside the loss, so the effective weights differ; our defaults (λ_E = 0.05, λ_F = 0.01) balance the three terms for unit-scale inputs.
Single model vs ensemble: reconstruct_tad_level returns the single best model from an n_ensemble batch by default. Pass return_all=True to get all ensemble members for downstream clustering or averaging.
Assembly here uses a single-anchor Kabsch rotation per TAD, a simplification of the paper's gradient-descent assembly. It is good enough for the tutorial; for production use you may want to follow the initial rotation with a short local refinement step.
FISH coverage: the Bintu 2018 CSV only covers chr21:18.6–20.6 Mb (65 bins out of the 1 557-bin chromosome). The FISH C_3 / C_4 terms influence only the TADs that fall inside that region; the rest of the chromosome is shaped by Hi-C alone, which is the expected behaviour when FISH coverage is partial.
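For reference, the Kabsch algorithm mentioned in the assembly note finds the proper rotation R that minimises Σ|R p_i − q_i|² between two paired, centred point sets, via an SVD of their cross-covariance with a sign correction that rules out reflections. This is the textbook algorithm, not uchrom's assemble code; which anchor points uchrom pairs for each TAD is not shown here, and kabsch is a hypothetical helper.

```python
import numpy as np


def kabsch(P, Q):
    """Proper rotation R (det +1) minimising sum ||R @ p_i - q_i||^2 after centring.

    P, Q: paired point sets of shape (n, 3).
    """
    Pc = P - P.mean(0)
    Qc = Q - Q.mean(0)
    H = Pc.T @ Qc                                  # 3x3 cross-covariance
    U, _, Vt = np.linalg.svd(H)
    d = np.sign(np.linalg.det(Vt.T @ U.T))         # -1 would mean a reflection
    return Vt.T @ np.diag([1.0, 1.0, d]) @ U.T


# Usage: recover a known rotation from noise-free paired points.
demo_rng = np.random.default_rng(0)
P_demo = demo_rng.normal(size=(20, 3))
R_true, _ = np.linalg.qr(demo_rng.normal(size=(3, 3)))
if np.linalg.det(R_true) < 0:                      # make it a proper rotation
    R_true[:, 0] *= -1
Q_demo = P_demo @ R_true.T + np.array([1.0, 2.0, 3.0])
R_est = kabsch(P_demo, Q_demo)
aligned = (P_demo - P_demo.mean(0)) @ R_est.T + Q_demo.mean(0)
print(np.abs(aligned - Q_demo).max())              # ≈ 0 for noise-free input
```

The sign correction is what keeps the result a rotation: for a mirrored target the best orthogonal matrix would be a reflection, and Kabsch instead returns the closest proper rotation.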