# Source code for uchrom.recon.fish._losses

"""Loss primitives for GEM-FISH.

Independent re-implementation of the loss terms from Abbas et al. 2019
(*Nat. Commun.*, doi:10.1038/s41467-019-10005-6).  Each primitive takes
a coordinate tensor of shape ``(*batch, n_bins, 3)`` (PyTorch) and
returns a scalar loss per batch item.  A leading batch dimension is
supported for ensemble training — stacking K independent models on the
first axis lets them share autograd and optimiser state at zero extra
code cost.

Summary of terms
----------------
- ``C_1`` — KL divergence between row-normalised Hi-C interaction
  frequencies and row-normalised inverse reconstructed distances.  This
  is the paper's Hi-C structural consistency term (Eqn. 3).
- ``C_2`` — polymer conformation energy (bond + angle + excluded
  volume).  The paper cites prior GEM / uchrom.recon.sc.gem for the
  functional form; implemented here as a sum of three standard
  harmonic / hinge terms.
- ``C_3`` — sum of squared errors between reconstructed pairwise
  distances and FISH-measured distances on TAD centres (Eqn. 6).
- ``C_4`` — L1 penalty on the deviation of squared radius-of-gyration
  from a FISH-derived target (Eqn. 7).

All lengths are in the coordinate system of ``coords`` (typically nm),
with the Hi-C-derived distances scaled to the same units.
"""

from __future__ import annotations

from typing import Optional

import torch
from torch import Tensor


# Numerical floor shared by every loss in this module: added under square
# roots (finite gradients at zero length) and used as the clamp floor for
# denominators and log arguments.
_EPS: float = 1.0e-12


def _pairwise_dist(coords: Tensor) -> Tensor:
    """Return ``(..., N, N)`` Euclidean distances for ``coords``
    shaped ``(..., N, 3)``.

    ``_EPS`` is added under the square root so the gradient stays finite
    for coincident points.  As a consequence the diagonal is *not*
    exactly zero: it equals ``sqrt(_EPS)`` (1e-6 for ``_EPS = 1e-12``),
    and every distance is biased upward by at most that amount.  Callers
    that must exclude self-pairs have to mask the diagonal explicitly.
    """
    diff = coords.unsqueeze(-2) - coords.unsqueeze(-3)  # (..., N, N, 3)
    return torch.sqrt((diff * diff).sum(-1).clamp_min(0.0) + _EPS)


# ----------------------------------------------------------------------
# C_1 — Hi-C KL divergence
# ----------------------------------------------------------------------

def hic_kl_loss(
    coords: Tensor,
    contacts: Tensor,
    mask: Optional[Tensor] = None,
) -> Tensor:
    """Row-wise KL divergence between normalised Hi-C and normalised
    inverse reconstructed distances.

    ``P_ij = f_ij / Σ_{k≠i} f_ik`` and
    ``Q_ij = (1 + d_ij)^{-1} / Σ_{k≠i} (1 + d_ik)^{-1}``, where
    ``d_ij = ||s_i - s_j||`` is the reconstructed 3-D distance.
    Returns ``Σ_i Σ_{j≠i} P_ij log(P_ij / Q_ij)`` summed over ``i``.

    Only pairs that are off-diagonal, have a finite contact value and
    (if given) are selected by ``mask`` enter either distribution, so
    ``P`` and ``Q`` are normalised over the same support.  Rows with no
    remaining contacts contribute zero.

    Parameters
    ----------
    coords : Tensor (..., N, 3)
    contacts : Tensor (N, N)
        Raw Hi-C interaction frequencies.  Symmetric; diagonal ignored.
        NaN / non-finite entries are ignored (treated as unmeasured).
    mask : Tensor (N, N), optional
        Boolean mask of ``(i, j)`` pairs to include (for e.g. dropping
        unmappable bins).  Default: all off-diagonal pairs.

    Returns
    -------
    Tensor (...)
        KL divergence per batch item.
    """
    N = coords.shape[-2]
    contacts = contacts.to(device=coords.device, dtype=coords.dtype)

    # Support shared by P and Q: off-diagonal, finite, user-selected.
    keep = torch.isfinite(contacts) & ~torch.eye(
        N, dtype=torch.bool, device=coords.device,
    )
    if mask is not None:
        keep = keep & mask.to(device=coords.device, dtype=torch.bool)

    # Select with ``where`` rather than multiplying by a 0/1 mask: NaN * 0
    # is still NaN, so multiplication cannot drop unmeasured entries.
    zero = coords.new_zeros(())

    # Row-normalise Hi-C: P_ij = f_ij / Σ_j f_ij
    f = torch.where(keep, contacts, zero)
    P = f / f.sum(-1, keepdim=True).clamp_min(_EPS)

    # Row-normalise inverse distances over the same support.
    inv = torch.where(keep, 1.0 / (1.0 + _pairwise_dist(coords)), zero)
    Q = inv / inv.sum(-1, keepdim=True).clamp_min(_EPS)

    # KL(P||Q) = Σ P log(P/Q), summed over j then i.  Terms with P_ij = 0
    # are 0 by convention; the clamps keep log() finite in the unselected
    # branch so no NaN leaks into the gradient through torch.where.
    logterm = torch.where(
        P > 0,
        torch.log(P.clamp_min(_EPS) / Q.clamp_min(_EPS)),
        torch.zeros_like(P),
    )
    return (P * logterm).sum(dim=(-1, -2))
# ----------------------------------------------------------------------
# C_2 — polymer energy (bond + angle + excluded volume)
# ----------------------------------------------------------------------
def polymer_energy(
    coords: Tensor,
    bond_length: float = 1.0,
    bond_weight: float = 1.0,
    angle_weight: float = 0.1,
    exclude_weight: float = 1.0,
    exclude_radius: float = 1.0,
) -> Tensor:
    """Bond-stretching + angle-bending + excluded-volume energy.

    With bond vectors ``b_i = s_{i+1} - s_i``:

    - Bond:    ``bond_weight * Σ_i (||b_i|| - bond_length)²``
    - Angle:   ``angle_weight * Σ_i θ_i²`` where ``θ_i`` is the angle
      between ``b_i`` and ``b_{i+1}`` (needs at least three beads).
    - Exclude: ``exclude_weight * Σ_{i<j} max(0, exclude_radius - ||s_i - s_j||)²``
      — a hinge over every unordered pair, bonded neighbours included.

    Matches the functional form used in the existing ``uchrom.recon.sc.gem``
    Taichi kernels.

    Parameters
    ----------
    coords : Tensor (..., N, 3)
    bond_length, bond_weight : float
        Rest length and weight of the harmonic bond term.
    angle_weight : float
        Weight of the angle-bending term.
    exclude_weight, exclude_radius : float
        Weight and contact radius of the excluded-volume hinge.

    Returns
    -------
    Tensor (...)
        Total energy per batch item; zeros when there are fewer than two
        beads.
    """
    n_beads = coords.shape[-2]
    if n_beads < 2:
        # No bonds, angles or pairs: the energy is identically zero.
        return torch.zeros(
            coords.shape[:-2], device=coords.device, dtype=coords.dtype,
        )

    # Bond stretching; _EPS under the sqrt keeps zero-length bonds
    # differentiable.
    bonds = coords[..., 1:, :] - coords[..., :-1, :]
    lengths = torch.sqrt((bonds * bonds).sum(-1).clamp_min(0.0) + _EPS)
    e_bond = bond_weight * ((lengths - bond_length) ** 2).sum(-1)

    # Angle bending between consecutive bond vectors.
    if n_beads < 3:
        e_angle = torch.zeros_like(e_bond)
    else:
        unit = bonds / bonds.norm(dim=-1, keepdim=True).clamp_min(_EPS)
        cos_theta = (unit[..., :-1, :] * unit[..., 1:, :]).sum(-1)
        # Stay strictly inside (-1, 1): the derivative of acos diverges
        # at the endpoints.
        theta = torch.acos(cos_theta.clamp(-1.0 + 1e-7, 1.0 - 1e-7))
        e_angle = angle_weight * (theta ** 2).sum(-1)

    # Excluded volume: hinge over the strict upper triangle (each
    # unordered pair once, self-pairs dropped).
    overlap = torch.clamp_min(exclude_radius - _pairwise_dist(coords), 0.0)
    upper = torch.triu(overlap, diagonal=1)
    e_exclude = exclude_weight * (upper ** 2).sum(dim=(-1, -2))

    return e_bond + e_angle + e_exclude
# ----------------------------------------------------------------------
# C_3 — FISH distance constraints
# ----------------------------------------------------------------------
def fish_distance_loss(
    coords: Tensor,
    fish_distances: Tensor,
    mask: Optional[Tensor] = None,
) -> Tensor:
    """``½ Σ_{(i,j) ∈ M} (||s_i - s_j|| - F_ij)²``.

    ``M`` is the set of ordered, off-diagonal pairs with a usable FISH
    measurement.  For a symmetric ``M`` the ½ factor makes this the sum
    of squared errors over unordered pairs (each pair counted once).

    Parameters
    ----------
    coords : Tensor (..., N, 3)
    fish_distances : Tensor (N, N)
        FISH-measured average pairwise distances.  Symmetric; NaN or
        non-finite entries are always ignored.
    mask : Tensor (N, N), optional
        Extra Boolean mask — set to ``False`` where no FISH measurement
        is available.  It is combined with the finiteness check.  If
        omitted, finite, strictly positive entries of ``fish_distances``
        define the mask automatically.

    Returns
    -------
    Tensor (...)
        Loss per batch item.
    """
    F = fish_distances.to(device=coords.device, dtype=coords.dtype)

    # Non-finite entries carry no measurement.  Exclude them even when
    # an explicit mask is given; otherwise a single NaN makes the loss NaN.
    valid = torch.isfinite(F)
    if mask is None:
        # Auto mask: non-positive distances also mean "not measured".
        valid = valid & (F > 0)
    else:
        valid = valid & mask.to(device=coords.device, dtype=torch.bool)
    valid = valid & ~torch.eye(
        F.shape[-1], dtype=torch.bool, device=coords.device,
    )

    d = _pairwise_dist(coords)
    # torch.where also zeroes the gradient for excluded pairs, so NaN
    # entries of F never reach the backward pass.
    diff = torch.where(valid, d - F, torch.zeros_like(d))
    # Count each (i, j) pair once: divide by 2 (symmetric matrix).
    return 0.5 * (diff * diff).sum(dim=(-1, -2))
# ----------------------------------------------------------------------
# C_4 — Radius of gyration L1 penalty
# ----------------------------------------------------------------------
def rg_squared(coords: Tensor) -> Tensor:
    """Squared radius of gyration ``(1/N) Σ ||s_i - s̄||²``.

    ``s̄`` is the centroid of the ``N`` beads (mean over axis ``-2``).
    The result has the batch shape ``coords.shape[:-2]``.
    """
    centred = coords - coords.mean(dim=-2, keepdim=True)
    return centred.pow(2).sum(dim=-1).mean(dim=-1)
def rg_loss(
    coords: Tensor,
    target_rg_sq: Tensor,
) -> Tensor:
    """``|R_g² - R̂_g²|`` — L1 penalty on the squared-Rg deviation.

    Parameters
    ----------
    coords : Tensor (..., N, 3)
    target_rg_sq : Tensor (...) or scalar
        FISH-derived ``R̂_g²``.  Converted to the dtype and device of
        ``coords`` and broadcast over the batch dimension.

    Returns
    -------
    Tensor
        Absolute deviation with the broadcast shape of the batch and target.
    """
    goal = torch.as_tensor(
        target_rg_sq, dtype=coords.dtype, device=coords.device,
    )
    return torch.abs(rg_squared(coords) - goal)
# Public API of this module: the four loss terms plus the Rg helper.
__all__ = [
    "hic_kl_loss",
    "polymer_energy",
    "fish_distance_loss",
    "rg_squared",
    "rg_loss",
]