"""Stage 1 of GEM-FISH — TAD-level coarse 3-D model.

Minimises the objective of Abbas et al. 2019 Eqn. (2)::

    C_g = C_1 + lambda_E * C_2 + lambda_F * C_3

- ``C_1``: row-wise KL between Hi-C and reconstructed inverse distances
  (:func:`._losses.hic_kl_loss`).
- ``C_2``: polymer energy (bond + angle + excluded volume)
  (:func:`._losses.polymer_energy`).
- ``C_3``: squared error against FISH-measured TAD-centre distances
  (:func:`._losses.fish_distance_loss`).

Gradient descent is run on an *ensemble* of ``n_ensemble`` independent
initialisations; the best single model (lowest final loss) is returned
by default but the full ensemble is also available via ``return_all``.

Loss weights in the paper (λ_E = 5e12, λ_F = 1e-8) assume raw (not
row-normalised) Hi-C counts as input; our PyTorch ``C_1`` is scale-free
after row-normalisation, so we expose smaller defaults that balance the
three terms for unit-normalised data — users should tune them for their
data.
"""
from __future__ import annotations
from dataclasses import dataclass
from typing import Optional, Tuple
import numpy as np
import torch
from torch import Tensor
from ._losses import (
hic_kl_loss,
polymer_energy,
fish_distance_loss,
)
@dataclass
class Stage1Params:
    """Optimisation parameters for Stage-1 (TAD-level) reconstruction.

    Attributes
    ----------
    lambda_E : float
        Weight of the polymer-energy term ``C_2`` in the total loss
        (relative to the KL term ``C_1``, whose weight is 1).
    lambda_F : float
        Weight of the FISH distance term ``C_3`` in the total loss.
    bond_length, bond_weight, angle_weight, exclude_weight, exclude_radius : float
        Forwarded unchanged to ``polymer_energy`` as the keyword
        arguments of the same names.
    n_iter : int
        Number of outer optimisation iterations (one optimiser step
        each; also the number of rows in ``loss_history``).
    lr : float
        Learning rate handed to the optimiser.
    optimiser : str
        ``"lbfgs"`` selects ``torch.optim.LBFGS``; any other value
        (including the default ``"adam"``) selects ``torch.optim.Adam``.
    n_ensemble : int
        Number of independent replicas optimised together. Ignored when
        a 3-D ``initial_coords`` array is supplied, whose leading
        dimension is used instead.
    init_scale : float
        Standard deviation of the Gaussian initialisation (used only
        when neither ``initial_coords`` nor the MDS seed applies).
    init_from_mds : bool
        If True (and FISH distances are supplied), seed the optimisation
        with a classical-MDS embedding of the FISH distance matrix, with
        missing entries imputed from the available FISH distances.
        Starting from an MDS solution that already respects the
        pairwise-distance structure gives the gradient descent a much
        better basin than N(0, σ²) — crucial on chromosomes with partial
        FISH coverage where the un-anchored centres otherwise drift into
        extended configurations. Each ensemble replica gets a small
        random perturbation on top of the MDS seed so the ensemble still
        has diversity.
    device : str
        ``"cuda"``, ``"mps"`` or ``"cpu"`` to force a device; any other
        value (default ``"auto"``) picks CUDA, then MPS, then CPU,
        whichever is available first.
    verbose : bool
        If True, print the loss breakdown roughly ten times per run and
        on the final iteration.
    """

    # NOTE: field order is part of the public (positional) constructor
    # signature — append new fields at the end only.

    # --- loss weights ---
    lambda_E: float = 0.05  # polymer weight (relative to KL)
    lambda_F: float = 0.01  # FISH distance weight

    # --- polymer-energy settings (passed through to polymer_energy) ---
    bond_length: float = 1.0
    bond_weight: float = 1.0
    angle_weight: float = 0.1
    exclude_weight: float = 1.0
    exclude_radius: float = 0.9

    # --- optimiser settings ---
    n_iter: int = 1000
    lr: float = 0.05
    optimiser: str = "adam"  # 'adam' | 'lbfgs'

    # --- ensemble / initialisation ---
    n_ensemble: int = 20
    init_scale: float = 3.0  # std of Gaussian initialisation
    init_from_mds: bool = True  # see class docstring

    # --- runtime ---
    device: str = "auto"  # 'auto' | 'cuda' | 'mps' | 'cpu'
    verbose: bool = False
# ----------------------------------------------------------------------
def _mds_seed(
fish_distances: np.ndarray,
contacts: np.ndarray,
) -> np.ndarray:
"""Classical-MDS 3-D embedding from a (mostly) pairwise distance matrix.
Missing FISH entries (NaN or ≤ 0) are filled in with the median of
available FISH distances for the same genomic separation, falling
back to a global median if no same-separation data is available.
The resulting dense distance matrix is centred via the double-
centring identity ``B = -0.5 J D² J``, eigendecomposed, and the
top-3 eigenvectors scaled by √eigenvalues give the 3-D coords.
When every FISH entry is missing or non-positive, returns a small
random point cloud so downstream code never sees NaNs.
"""
D = np.asarray(fish_distances, dtype=np.float64).copy()
n = D.shape[0]
if n < 2:
return np.zeros((max(n, 1), 3), dtype=np.float64)
# Mark missing
missing = ~np.isfinite(D) | (D <= 0)
np.fill_diagonal(missing, False)
# Fill by same-|i-j| median
if missing.any():
for sep in range(1, n):
idx_i, idx_j = np.where(
(np.abs(np.arange(n)[:, None] - np.arange(n)[None, :]) == sep)
)
good = ~missing[idx_i, idx_j]
if good.any():
med = float(np.median(D[idx_i[good], idx_j[good]]))
fill = (
(np.abs(np.arange(n)[:, None] - np.arange(n)[None, :]) == sep)
& missing
)
D[fill] = med
else:
# Use global median as last-resort fill
global_med = float(np.median(D[~missing])) if (~missing).any() else 1.0
fill = (
(np.abs(np.arange(n)[:, None] - np.arange(n)[None, :]) == sep)
& missing
)
D[fill] = global_med
# Classical MDS: B = -0.5 J D² J, eigendecompose, top-3
D2 = D * D
J = np.eye(n) - np.ones((n, n)) / n
B = -0.5 * J @ D2 @ J
B = 0.5 * (B + B.T) # numerical symmetry
w, v = np.linalg.eigh(B)
idx = np.argsort(w)[::-1][:3]
top_w = np.clip(w[idx], 0.0, None)
coords = v[:, idx] * np.sqrt(top_w)[None, :]
return coords.astype(np.float64)
def _resolve_device(spec: str) -> Tuple[torch.device, torch.dtype]:
"""Return ``(device, dtype)`` — MPS is float32-only, others float64."""
if spec == "cuda":
return torch.device("cuda"), torch.float64
if spec == "mps":
return torch.device("mps"), torch.float32
if spec == "cpu":
return torch.device("cpu"), torch.float64
# auto
if torch.cuda.is_available():
return torch.device("cuda"), torch.float64
if getattr(torch.backends, "mps", None) and torch.backends.mps.is_available():
return torch.device("mps"), torch.float32
return torch.device("cpu"), torch.float64
def reconstruct_tad_level(
    contacts: np.ndarray,
    fish_distances: Optional[np.ndarray],
    params: Optional[Stage1Params] = None,
    initial_coords: Optional[np.ndarray] = None,
    return_all: bool = False,
) -> Tuple[np.ndarray, dict]:
    """Stage-1 TAD-level optimisation.

    Minimises ``C_1 + lambda_E * C_2 + lambda_F * C_3`` independently
    for every replica of an ensemble of 3-D chains and returns either
    the whole ensemble or the replica with the lowest final loss.

    Parameters
    ----------
    contacts : array-like (N, N)
        Inter-TAD Hi-C contact matrix (sums over TAD windows). Raw or
        normalised — row-normalisation happens inside ``C_1``.
    fish_distances : ndarray (N, N), optional
        FISH-measured pairwise distances between TAD centres. ``NaN``
        or non-positive entries are ignored. If ``None``, the FISH
        term is dropped (``C_3`` is reported as zeros).
    params : Stage1Params, optional
        Optimisation settings; defaults to ``Stage1Params()``.
    initial_coords : ndarray (N, 3) or (n_ensemble, N, 3), optional
        Override the MDS / Gaussian initialisation. A 2-D array is
        replicated ``params.n_ensemble`` times; a 3-D array defines the
        ensemble size through its leading dimension.
    return_all : bool
        If True, return the full ensemble ``(n_ensemble, N, 3)``;
        otherwise return the best single model ``(N, 3)``.

    Returns
    -------
    coords : ndarray
    info : dict with keys ``final_loss`` (per-ensemble),
        ``best_ensemble_idx``, ``loss_history`` (shape
        ``(n_iter, n_ensemble)``), and the split ``C_1`` / ``C_2`` /
        ``C_3`` at the end.

    Notes
    -----
    Replicas whose final loss is non-finite (diverged) are never picked
    as the best model unless every replica is non-finite, in which case
    index 0 is returned.

    Unless ``initial_coords`` is given, the global torch RNG is re-seeded
    with ``torch.manual_seed(0)`` so runs are reproducible.
    """
    if params is None:
        params = Stage1Params()
    device, dtype = _resolve_device(params.device)
    K = params.n_ensemble

    # --- Tensors ---
    C = torch.as_tensor(contacts, dtype=dtype, device=device)
    N = C.shape[0]  # read from the tensor so array-likes work too
    F = None
    if fish_distances is not None:
        F = torch.as_tensor(fish_distances, dtype=dtype, device=device)

    # --- Initial coordinates ---
    if initial_coords is not None:
        init = torch.as_tensor(initial_coords, dtype=dtype, device=device)
        if init.dim() == 2:
            init = init.unsqueeze(0).expand(K, -1, -1).clone()
        elif init.dim() == 3:
            K = init.shape[0]
        # clone: as_tensor may share memory with the caller's array
        coords = init.clone()
    elif params.init_from_mds and fish_distances is not None:
        mds_seed = _mds_seed(
            np.asarray(fish_distances, dtype=np.float64),
            np.asarray(contacts, dtype=np.float64),
        )
        torch.manual_seed(0)
        seed = torch.as_tensor(mds_seed, dtype=dtype, device=device)
        # Replicate across ensemble with small perturbation for diversity
        coords = seed.unsqueeze(0).expand(K, -1, -1).clone()
        coords = coords + 0.1 * torch.randn_like(coords)
    else:
        # MPS RNG doesn't support manual_seed on generators yet; seed torch global
        torch.manual_seed(0)
        coords = params.init_scale * torch.randn(
            K, N, 3, dtype=dtype, device=device,
        )
    coords.requires_grad_(True)

    # --- Optimiser ---
    use_lbfgs = params.optimiser == "lbfgs"
    if use_lbfgs:
        opt = torch.optim.LBFGS(
            [coords], lr=params.lr, max_iter=20, line_search_fn="strong_wolfe",
        )
    else:
        opt = torch.optim.Adam([coords], lr=params.lr)

    history = np.zeros((params.n_iter, K), dtype=np.float64)

    def _total_loss(x: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
        """Return per-replica ``(total, C_1, C_2, C_3)`` for coordinates ``x``."""
        l1 = hic_kl_loss(x, C)
        l2 = polymer_energy(
            x,
            bond_length=params.bond_length,
            bond_weight=params.bond_weight,
            angle_weight=params.angle_weight,
            exclude_weight=params.exclude_weight,
            exclude_radius=params.exclude_radius,
        )
        l3 = (
            fish_distance_loss(x, F)
            if F is not None
            else torch.zeros_like(l1)
        )
        total = l1 + params.lambda_E * l2 + params.lambda_F * l3
        return total, l1, l2, l3

    def _best_replica(losses: np.ndarray) -> int:
        """Index of the lowest finite loss (0 if none is finite)."""
        # np.argmin returns the position of a NaN when one is present,
        # which would crown a diverged replica as the "best" model.
        return int(np.argmin(np.where(np.isfinite(losses), losses, np.inf)))

    def _lbfgs_closure():
        """Re-evaluate the summed ensemble loss for LBFGS's line search."""
        opt.zero_grad()
        total, _, _, _ = _total_loss(coords)
        total_sum = total.sum()
        total_sum.backward()
        return total_sum

    report_every = max(1, params.n_iter // 10)

    for it in range(params.n_iter):
        if use_lbfgs:
            opt.step(_lbfgs_closure)
            with torch.no_grad():
                total, l1, l2, l3 = _total_loss(coords)
        else:
            opt.zero_grad()
            total, l1, l2, l3 = _total_loss(coords)
            # Replicas are independent, so summing gives each its own gradient.
            total.sum().backward()
            opt.step()
        history[it] = total.detach().cpu().numpy()

        if params.verbose and (it % report_every == 0
                               or it == params.n_iter - 1):
            # Report the same evaluation that was stored in `history`,
            # so total and its components are mutually consistent.
            best = _best_replica(history[it])
            print(
                f"[stage1] iter {it:5d}/{params.n_iter} "
                f"total={total[best]:.4f} "
                f"C1={l1[best]:.4f} "
                f"C2={l2[best]:.4f} "
                f"C3={l3[best]:.4f} (ens best={best})"
            )

    # --- Final evaluation at the optimised coordinates ---
    with torch.no_grad():
        total, l1, l2, l3 = _total_loss(coords)
        final = total.cpu().numpy()
    best_idx = _best_replica(final)
    info = {
        "final_loss": final,
        "best_ensemble_idx": best_idx,
        "loss_history": history,
        "C_1": l1.cpu().numpy(),
        "C_2": l2.cpu().numpy(),
        "C_3": l3.cpu().numpy(),
    }
    out = coords.detach().cpu().numpy()
    if return_all:
        return out, info
    return out[best_idx], info
# Public API of this module.
__all__ = [
    "Stage1Params",
    "reconstruct_tad_level",
]