"""Stage 3 of GEM-FISH — assemble per-TAD intra models under the
TAD-level scaffold.
For each TAD i with intra-TAD coordinates ``y^{(i)}`` and a target
TAD-centre position ``s_i`` from Stage 1:
1. **Translate** each TAD so its centroid (mean of its bin coords)
coincides with ``s_i``.
2. **Rotate / reflect** each TAD (except the first) to minimise the
distance between its start point and the previous TAD's end point,
penalised against the expected inter-TAD gap ``d_{i-1,i}``
(computed from Hi-C contact counts via Eqn. 10).
Step 2 is solved independently per TAD. The rigid-body transform is
constrained to keep the centroid fixed at ``s_i``, so the only remaining
freedom is a rotation (or reflection) about that centroid. Minimising
``||y^{(i)}_{start} − (y^{(i-1)}_{end} + d̂_{i-1,i})||²`` is therefore a
closed-form Kabsch problem on a **single vector pair**: rotate the TAD
around its centroid so that the vector ``centroid → start`` points
toward the desired target. Because the objective is a single squared
distance, a reflection can never do better than a rotation here, so
this simple alignment is also the robust choice. The rotation is
obtained from the SVD of the outer product of the two vectors.
"""
from __future__ import annotations
from dataclasses import dataclass
from typing import List, Optional, Tuple
import numpy as np
@dataclass
class AssemblyParams:
    """Tunable settings for the Stage-3 assembly step.

    Attributes
    ----------
    allow_reflection : bool
        Forwarded to the Kabsch solver used for the one-shot alignment.
        When ``False`` a negative-determinant solution is corrected to
        a proper rotation; when ``True`` the raw SVD solution is used,
        which may be a reflection.
    endpoint_k : int
        Number of bins taken from each TAD end and averaged into the
        'anchor' point for the one-shot alignment. ``1`` uses just the
        first / last bin; ``3`` uses the mean of the first / last 3
        bins (more robust to noise). Values below 1 are treated as 1.
    match_scales : bool
        If True, uniformly rescale the Stage-1 TAD-centre layout so the
        mean gap between consecutive centres matches the intra-TAD Rg
        scale (roughly tangent packing of TADs). Without this, Stage 1
        and Stage 2 end up on incompatible scales: Stage-1 centres
        drift apart under minimal polymer constraint while each Stage-2
        intra-TAD cloud optimises to unit bond length, giving tight
        blobs with huge bridges between them. The intra-TAD geometry
        (which Stage 2 optimised against real Hi-C) is preserved; only
        the global layout is rescaled.
    centre_gap_factor : float
        When ``match_scales=True``, the TAD centres are rescaled so
        that ``mean_inter_TAD_gap = centre_gap_factor *
        mean_intra_TAD_Rg``. ``2.0`` corresponds to roughly-tangent TAD
        packing; ``1.5`` gives slight overlap between neighbours;
        ``3.0`` gives visible spacing.
    iterative : bool
        If True (and inter-TAD distances are supplied), follow the
        one-shot alignment with the upstream GEM-FISH iterative
        gradient descent (Abbas 2019 Eqn. 13): for each adjacent-TAD
        pair, rotate TAD i+1 around its centre so the gap between
        ``y_end[i]`` and ``y_start[i+1]`` converges toward
        ``d_prior[i, i+1]``, with a flip-and-retry for TADs whose
        boundary gap has > ``reflect_threshold`` relative error.
    inner_max_iter : int
        Maximum gradient-descent iterations per inner assembly pass.
    outer_max_iter : int
        Maximum number of outer passes; between passes, badly fitting
        TADs are flipped before the inner descent is run again.
    alpha : float
        Initial gradient-descent learning rate for the iterative
        assembly.
    reflect_threshold : float
        TADs whose adjacent-boundary relative error exceeds this
        threshold are flipped through their centre before the next
        outer pass (matches the upstream reflection-retry strategy).
    convergence_abs : float
        The inner loop stops once ``Σ|d_prior - df| < convergence_abs``.
    """

    # One-shot (Stage 3b) alignment.
    allow_reflection: bool = True
    endpoint_k: int = 1

    # Stage-1 / Stage-2 scale harmonisation.
    match_scales: bool = True
    centre_gap_factor: float = 2.0

    # Iterative (Stage 3c) refinement.
    iterative: bool = True
    inner_max_iter: int = 5000
    outer_max_iter: int = 5
    alpha: float = 0.1
    reflect_threshold: float = 0.2
    convergence_abs: float = 0.1
def _kabsch(
    P: np.ndarray, Q: np.ndarray, allow_reflection: bool = True,
) -> np.ndarray:
    """Solve ``min_R ||P R − Q||²`` over orthogonal 3×3 matrices.

    ``P`` and ``Q`` are ``(k, 3)`` arrays of corresponding points that
    are **already centred** (centroids at the origin). The result is
    meant to be right-multiplied: ``new_coords = old_coords @ R``.

    With ``allow_reflection=False`` a negative-determinant solution is
    turned into a proper rotation by flipping the last singular
    direction; otherwise the raw SVD solution is returned unchanged.
    """
    left, _, right_t = np.linalg.svd(P.T @ Q)
    handedness = np.sign(np.linalg.det(left @ right_t))
    # Identity unless an unwanted reflection has to be undone.
    correction = np.eye(3)
    if handedness < 0 and not allow_reflection:
        correction[-1, -1] = handedness
    return left @ correction @ right_t
def assemble(
    tad_centres: np.ndarray,
    intra_tad_coords: List[np.ndarray],
    inter_tad_distances: Optional[np.ndarray] = None,
    params: Optional["AssemblyParams"] = None,
) -> np.ndarray:
    """Stitch intra-TAD conformations together into a single chain.

    Parameters
    ----------
    tad_centres : ndarray (n_tads, 3)
        Stage-1 TAD-centre coordinates. Not modified.
    intra_tad_coords : list of ndarray (n_bins_i, 3)
        Stage-2 intra-TAD coordinate clouds (one per TAD). Empty clouds
        are tolerated and contribute no rows to the output.
    inter_tad_distances : ndarray (n_tads, n_tads), optional
        Pairwise expected distances used for Stage-3 alignment. Only
        the immediate-neighbour entries ``d[i, i+1]`` are consulted,
        and they are interpreted as the expected gap between the end of
        TAD ``i`` and the start of TAD ``i+1``. A non-finite entry
        disables the one-shot alignment of that TAD. If the matrix is
        omitted there is no gap to aim for, so both the one-shot
        alignment and the iterative refinement are skipped and every
        TAD keeps the orientation of its Stage-2 model.
    params : AssemblyParams
        Assembly settings; defaults are used when omitted.

    Returns
    -------
    coords : ndarray (sum n_bins_i, 3)
        Concatenated, rigid-body-aligned per-bin coordinates. Order is
        TAD-0 bins first, then TAD-1, …
    """
    if params is None:
        params = AssemblyParams()
    # Work on float arrays so list inputs are accepted and the caller's
    # arrays are never rescaled in place.
    tad_centres = np.asarray(tad_centres, dtype=np.float64)
    if inter_tad_distances is not None:
        inter_tad_distances = np.asarray(inter_tad_distances, dtype=np.float64)
    n_tads = tad_centres.shape[0]
    assert len(intra_tad_coords) == n_tads, (
        "intra_tad_coords length must match tad_centres length"
    )
    if n_tads == 0:
        # np.vstack rejects an empty list; return an empty chain instead.
        return np.empty((0, 3), dtype=np.float64)

    # Harmonise Stage-1 (TAD-centre) and Stage-2 (intra-TAD) scales.
    # The intra-TAD geometry carries the information that Stage 2
    # optimised against real Hi-C, so the clouds are left intact and
    # the TAD-centre layout is rescaled instead. Target: mean gap
    # between consecutive centres ≈ ``centre_gap_factor * mean
    # intra-TAD Rg`` (tangent-packed TAD spheres for factor=2).
    if params.match_scales and n_tads >= 2:
        intra_rgs = []
        for clump in intra_tad_coords:
            cl = np.asarray(clump)
            if cl.shape[0] >= 2:
                centroid = cl.mean(axis=0)
                intra_rgs.append(
                    float(np.sqrt(((cl - centroid) ** 2).sum(axis=1).mean()))
                )
        if intra_rgs:
            mean_rg = float(np.mean(intra_rgs))
            mean_gap = float(np.mean(
                np.linalg.norm(tad_centres[1:] - tad_centres[:-1], axis=1)
            ))
            if mean_gap > 1e-9:
                scale = params.centre_gap_factor * mean_rg / mean_gap
                tad_centres = tad_centres * scale
                if inter_tad_distances is not None:
                    # Keep the expected gaps in the same units as the
                    # rescaled centres.
                    inter_tad_distances = inter_tad_distances * scale

    # Stage 3a: centre each intra-TAD cloud on its (optionally rescaled)
    # TAD centre — the intra geometry is preserved.
    placed = []
    for centre, clump in zip(tad_centres, intra_tad_coords):
        cl = np.asarray(clump, dtype=np.float64)
        if cl.shape[0] == 0:
            # An empty TAD has no centroid; keep a well-shaped
            # placeholder so the later stages and vstack still work.
            placed.append(np.empty((0, 3), dtype=np.float64))
            continue
        placed.append(cl - cl.mean(axis=0, keepdims=True) + centre)

    # Stage 3b: rotate each TAD (i ≥ 1) about its centre so its start
    # anchor points at the spot ``d[i-1, i]`` away from the previous
    # TAD's end anchor, along the line toward the current centre.
    # Without expected gaps that spot would coincide with the TAD
    # centre itself, which defines no direction, so the stage only runs
    # when gaps are supplied.
    k = max(1, int(params.endpoint_k))
    if inter_tad_distances is not None:
        for i in range(1, n_tads):
            prev = placed[i - 1]
            curr = placed[i]
            if prev.shape[0] < k or curr.shape[0] < k:
                continue
            centre = tad_centres[i]
            prev_end = prev[-k:].mean(axis=0)
            # Heading from the previous end anchor toward this centre.
            heading = centre - prev_end
            hn = float(np.linalg.norm(heading))
            if hn < 1e-9:
                continue
            target_gap = float(inter_tad_distances[i - 1, i])
            if not np.isfinite(target_gap):
                # Missing prior: leave this TAD as placed rather than
                # feeding NaN/inf into the SVD.
                continue
            target_start = prev_end + (heading / hn) * target_gap
            # Anchor and target expressed relative to the TAD centre
            # (anchors are collapsed to their mean, so this is always a
            # single-pair alignment regardless of k).
            v_curr = curr[:k].mean(axis=0) - centre
            v_target = target_start - centre
            # A (near-)zero vector has no direction; aligning to it
            # would let rounding noise pick the rotation.
            tol = 1e-9 * max(1.0, hn)
            if np.linalg.norm(v_curr) < tol or np.linalg.norm(v_target) < tol:
                continue
            R = _kabsch(
                v_curr[np.newaxis, :],
                v_target[np.newaxis, :],
                allow_reflection=params.allow_reflection,
            )
            # Apply the rotation around the TAD centre.
            placed[i] = (curr - centre) @ R + centre

    # Stage 3c: optional iterative gradient-descent + reflection retry.
    if params.iterative and n_tads >= 2 and inter_tad_distances is not None:
        placed = _iterative_rotate_assemble(
            placed, tad_centres, inter_tad_distances,
            params.inner_max_iter,
            params.outer_max_iter,
            params.alpha,
            params.reflect_threshold,
            params.convergence_abs,
        )
    return np.vstack(placed)
def _rotation_matrix(theta: float, axis: np.ndarray) -> np.ndarray:
    """Build the 3×3 matrix rotating by ``theta`` about ``axis``.

    Uses Rodrigues' formula in its matrix form
    ``cos θ · I + sin θ · [u]× + (1 − cos θ) · u uᵀ`` where ``u`` is
    ``axis`` normalised to unit length. The norm is floored at
    ``1e-12`` so a zero axis does not divide by zero. The matrix acts
    on column vectors (``R @ v``).
    """
    unit = axis / max(np.linalg.norm(axis), 1e-12)
    ux, uy, uz = unit
    cos_t = np.cos(theta)
    sin_t = np.sin(theta)
    # Cross-product (skew-symmetric) matrix of the unit axis.
    skew = np.array([
        [0.0, -uz, uy],
        [uz, 0.0, -ux],
        [-uy, ux, 0.0],
    ])
    return cos_t * np.eye(3) + sin_t * skew + (1.0 - cos_t) * np.outer(unit, unit)
def _rotate_around_centre(
    segment: np.ndarray, centre: np.ndarray, R: np.ndarray,
) -> np.ndarray:
    """Apply ``R`` to the rows of ``segment`` about the point ``centre``.

    ``R`` follows the column-vector convention (``R @ v``), so the row
    vectors are multiplied by its transpose. A new array is returned.
    """
    offsets = segment - centre
    turned = np.matmul(offsets, R.T)
    return turned + centre
def _iterative_rotate_assemble(
    placed: list,
    tad_centres: np.ndarray,
    inter_tad_distances: np.ndarray,
    inner_max_iter: int,
    outer_max_iter: int,
    alpha: float,
    reflect_threshold: float,
    convergence_abs: float,
) -> list:
    """Upstream GEM-FISH iterative assembly (Abbas 2019 Eqn. 13 + reflection).

    For each adjacent TAD pair ``(i, i+1)``:

    - Measure boundary distance ``df[i] = ||y_end[i] - y_start[i+1]||``
    - Compute gradient
      ``der[i] = 2(df[i] − d_prior[i])/df[i] × (y_start[i+1] − y_end[i])``
    - Move ``y_start[i+1]`` toward ``y_start[i+1] − lr × der[i]``
    - Rotate TAD ``i+1`` around its centre so its first point turns
      toward the new target position.

    Between outer passes, any TAD whose boundary-gap relative error is
    above ``reflect_threshold`` is flipped (x and y mirrored through its
    centre) and the inner loop runs again. No flip is applied after the
    last pass, because nothing would refine the flipped state.

    Parameters
    ----------
    placed : list of ndarray (n_bins_i, 3)
        Per-TAD coordinates already centred on ``tad_centres``. The
        inputs are copied, not modified.
    tad_centres : ndarray (n_tads, 3)
        Rotation centre of each TAD.
    inter_tad_distances : ndarray (n_tads, n_tads)
        Only ``d[i, i+1]`` is read. Non-finite or non-positive entries
        mark the pair as having no prior; such pairs are ignored.
    inner_max_iter, outer_max_iter : int
        Iteration caps for the descent and for the flip-and-retry loop.
    alpha : float
        Initial learning rate. It is doubled after every 1000th inner
        iteration and is not reset between outer passes.
    reflect_threshold : float
        Relative gap error above which a TAD is flipped.
    convergence_abs : float
        The descent stops once the summed absolute gap error is below
        this value.

    Returns
    -------
    list of ndarray
        The refined per-TAD coordinate arrays, in input order.
    """
    placed = [np.asarray(p, dtype=np.float64).copy() for p in placed]
    n_tads = len(placed)
    if n_tads < 2:
        # No adjacent pair, hence no boundary gap to optimise.
        return placed
    n_pairs = n_tads - 1

    # d_prior for adjacent pairs only.
    d_prior = np.array([
        float(inter_tad_distances[i, i + 1]) for i in range(n_pairs)
    ])
    # If any are 0 or non-finite, leave as no-op (matches upstream when
    # both Hi-C and genomic fallback yield 0).
    d_prior = np.where(np.isfinite(d_prior) & (d_prior > 0), d_prior, np.nan)

    # A boundary gap only exists when both neighbours have at least one
    # bin; rotations never change bin counts, so compute this once.
    has_bins = np.array([p.shape[0] >= 1 for p in placed], dtype=bool)
    pair_has_bins = has_bins[:-1] & has_bins[1:]
    # Pairs that contribute to the error: measurable and with a prior.
    scored = pair_has_bins & np.isfinite(d_prior)

    lr = alpha
    for outer in range(outer_max_iter):
        # --- Inner gradient descent ---
        prev_err_sum = float("inf")
        stagnant = 0
        for inner in range(inner_max_iter):
            df = np.zeros(n_pairs)
            der = np.zeros((n_pairs, 3))
            ok_pairs = np.zeros(n_pairs, dtype=bool)
            for i in range(n_pairs):
                if not pair_has_bins[i]:
                    continue
                diff = placed[i + 1][0] - placed[i][-1]
                d = np.linalg.norm(diff)
                df[i] = d
                # No prior, or coincident points (gradient direction
                # undefined): record the gap but do not move anything.
                if not np.isfinite(d_prior[i]) or d < 1e-9:
                    continue
                der[i] = (2.0 * (d - d_prior[i]) / d) * diff
                ok_pairs[i] = True

            err_sum = float(
                np.where(scored, np.abs(d_prior - df), 0.0).sum()
            )
            if err_sum < convergence_abs:
                break
            # Give up on this pass after more than 50 consecutive
            # iterations with essentially unchanged total error.
            if abs(err_sum - prev_err_sum) < 0.01:
                stagnant += 1
            else:
                stagnant = 0
            if stagnant > 50:
                break
            prev_err_sum = err_sum
            # Every 1000 iters, escalate learning rate (as upstream does)
            if (inner + 1) % 1000 == 0:
                lr = lr * 2.0

            # Apply rotations. Gradients come from the state at the top
            # of this iteration; rotations are applied one TAD at a time.
            for i in range(n_pairs):
                if not ok_pairs[i]:
                    continue
                centre = tad_centres[i + 1]
                y_start = placed[i + 1][0]
                y_new = y_start - lr * der[i]
                v1 = y_start - centre
                v2 = y_new - centre
                n1 = np.linalg.norm(v1)
                n2 = np.linalg.norm(v2)
                if n1 < 1e-9 or n2 < 1e-9:
                    continue
                cos_a = float(np.clip(
                    np.dot(v1, v2) / (n1 * n2), -1.0, 1.0,
                ))
                angle = float(np.arccos(cos_a))
                if angle < 1e-3:
                    continue
                # Axis perpendicular to both vectors turns v1 toward v2.
                axis = np.cross(v1, v2)
                if np.linalg.norm(axis) < 1e-9:
                    continue
                R = _rotation_matrix(angle, axis)
                placed[i + 1] = _rotate_around_centre(
                    placed[i + 1], centre, R,
                )

        # --- Outer loop: reflection retry for high-error TADs ---
        if outer == outer_max_iter - 1:
            # A flip is only useful if another descent pass follows.
            break
        final_df = np.zeros(n_pairs)
        for i in range(n_pairs):
            if pair_has_bins[i]:
                final_df[i] = np.linalg.norm(
                    placed[i + 1][0] - placed[i][-1]
                )
        with np.errstate(divide="ignore", invalid="ignore"):
            relative = np.where(
                scored,
                np.abs(d_prior - final_df) / d_prior,
                0.0,
            )
        reflected_any = False
        for i in range(n_pairs):
            if relative[i] > reflect_threshold:
                # Mirror TAD i+1 through its centre in x and y (z is
                # left untouched, so this equals a half-turn about the
                # z-axis through the centre).
                centre = tad_centres[i + 1]
                reflected = placed[i + 1].copy()
                reflected[:, 0] = 2 * centre[0] - reflected[:, 0]
                reflected[:, 1] = 2 * centre[1] - reflected[:, 1]
                placed[i + 1] = reflected
                reflected_any = True
        if not reflected_any:
            break
    return placed
# Names exported by ``from <module> import *``.
# NOTE(review): "gap_distance_from_contacts" is neither defined nor imported
# in the part of this file visible here. If it is not provided elsewhere in
# the module, a star-import will raise AttributeError — confirm it exists or
# drop it from this list.
__all__ = ["AssemblyParams", "assemble", "gap_distance_from_contacts"]