"""jie — spatial genome aligner entry point.
Independent reimplementation of Jia & Ren 2022 (*Nature Biotechnology*
22, doi:10.1038/s41587-022-01568-9). Not derived from
``github.com/b2jia/jie`` (GPL-3.0).
Usage
-----
>>> from uchrom.im.trace import align_spots, SpotAlignerParams
>>> cd = align_spots(spots_df, params=SpotAlignerParams())
Input contract
--------------
``spots_df`` is a long-format ``pandas.DataFrame``, one row per FISH
detection. Required columns:
- ``cell_id`` — grouping key for per-cell tracing
- ``chrom`` — decoded chromosome
- ``start``,``end`` — decoded genomic locus (bp)
- ``x``,``y``,``z`` — 3-D coordinates (nm)
Optional columns:
- ``sigma_x``,``sigma_y``,``sigma_z`` — per-spot axis-wise localisation
uncertainty (nm). Unless all three columns are present,
``params.default_sigma_nm`` is used for every spot.
- ``spot_id`` — arbitrary label carried through to the output.
Unlike a FOF-CT table, ``spots_df`` may have **multiple rows per
(cell_id, chrom, start, end)** — the aligner disambiguates them.
Output
------
A :class:`~uchrom.core.ChromData` whose ``spots`` table contains the
kept detections with a populated ``trace_id``. ``cd.traces`` has one
row per fiber with ``trace_id, cell_id, chrom, n_spots, path_weight,
mean_edge, sister_group``.
``cd.results['ploidy']`` summarises the per-(cell, chrom) fiber count.
``cd.uns['jie_polymer']`` holds the per-chromosome τ fit.
"""
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple
import numpy as np
import pandas as pd
from ._polymer import (
PERSISTENCE_LENGTH_BP_DEFAULT,
fit_scale_factor,
)
from ._graph import (
iterative_karyotype,
detect_sister_pairs,
)
[docs]
@dataclass
class SpotAlignerParams:
    """Runtime parameters for :func:`align_spots`."""
    # ---------- polymer model ----------
    # Persistence length in bp; forwarded as ``l_p_bp`` to the τ fit and to
    # the karyotyping step.
    persistence_length_bp: float = PERSISTENCE_LENGTH_BP_DEFAULT
    tau_nm_per_bp: Optional[float] = None
    """Genomic-to-spatial scale. If ``None``, fit per-chromosome from
    the aggregate observed-distance-vs-genomic-distance curve."""
    scaling_exponent_fixed: bool = True
    """If True, force α=0.5 (ideal Gaussian chain) in the τ fit. Set
    False to fit α jointly — useful for crumpled-globule chromatin."""
    default_sigma_nm: float = 50.0
    """Fallback localisation σ when ``sigma_*`` columns absent from
    ``spots_df`` (axis-wise, nm)."""
    # ---------- graph construction ----------
    # Forwarded to ``iterative_karyotype`` as ``max_skip``; presumably the
    # largest number of loci an edge may skip — confirm against ``_graph``.
    max_skip: int = 3
    gap_penalty: float = 3.0
    """Additive per-skip penalty in the Mahalanobis-cost edge weight.
    A skip-0 edge (adjacent locus) adds nothing; a skip-2 edge adds
    ``2 * gap_penalty``. Calibrate against typical ``R² / 2S²``
    (around 1.5 for on-chain matches)."""
    # ---------- karyotype ----------
    # Forwarded to ``iterative_karyotype`` as ``max_fibers``.
    max_fibers_per_chrom: int = 8
    # Forwarded to ``iterative_karyotype`` unchanged; presumably fibers with
    # fewer spots are rejected — confirm against ``_graph``.
    min_spots_per_fiber: int = 5
    # Forwarded to ``iterative_karyotype`` unchanged; the default of +inf
    # presumably disables the cutoff — confirm against ``_graph``.
    mean_edge_weight_cutoff: float = np.inf
    # ---------- sister detection ----------
    # When False, sister detection is skipped and every fiber gets its own
    # sister group.
    detect_sisters: bool = True
    # Forwarded to ``detect_sister_pairs`` as ``pair_radius_nm``.
    sister_pair_radius_nm: float = 300.0
    # ---------- runtime ----------
    # When True, ``align_spots`` prints progress messages to stdout.
    verbose: bool = False
# ----------------------------------------------------------------------
# Per-chromosome τ fitting (dataset-wide)
# ----------------------------------------------------------------------
def _fit_tau_global(
    spots_df: pd.DataFrame,
    params: SpotAlignerParams,
) -> Dict[str, Tuple[float, float]]:
    """Fit the genomic-to-spatial scale τ for every chromosome.

    For each chromosome one representative detection is kept per
    ``(cell_id, start, end)`` — the first row in input order — and every
    within-cell pair of loci with a non-zero midpoint separation
    contributes one ``(R, L)`` sample, where ``R`` is the Euclidean
    distance between the two spots and ``L`` the absolute difference of
    the locus midpoints.  All samples of a chromosome are handed to
    :func:`fit_scale_factor`.

    NOTE(review): any median/binning or log-log regression happens, if at
    all, inside ``fit_scale_factor``; this function only collects the raw
    pair samples.

    Parameters
    ----------
    spots_df
        Long-format detections with at least ``chrom``, ``cell_id``,
        ``start``, ``end``, ``x``, ``y``, ``z``.
    params
        Supplies ``persistence_length_bp`` and ``scaling_exponent_fixed``
        for the fit.

    Returns
    -------
    Dict[str, (tau, alpha)]
        Keyed by ``str(chrom)``.  Chromosomes with fewer than 10 usable
        pairs get the fallback ``(0.01, 0.5)``.
    """
    locus_keys = ["cell_id", "start", "end"]
    min_pairs = 10
    fallback: Tuple[float, float] = (0.01, 0.5)

    out: Dict[str, Tuple[float, float]] = {}
    for chrom, chrom_df in spots_df.groupby("chrom", observed=True):
        # One whole row per (cell, locus).  ``GroupBy.first()`` is avoided on
        # purpose: it takes the first *non-null value per column*, which can
        # combine x/y/z from different detections.  τ-fitting only needs a
        # rough scale, so which detection is picked does not matter.
        reps = (
            chrom_df.dropna(subset=locus_keys)
            .drop_duplicates(subset=locus_keys, keep="first")
            .sort_values(["start", "end"], kind="mergesort")
        )

        dist_chunks: List[np.ndarray] = []
        sep_chunks: List[np.ndarray] = []
        for _, cell_spots in reps.groupby("cell_id", observed=True):
            n = len(cell_spots)
            if n < 3:
                continue
            xyz = cell_spots[["x", "y", "z"]].to_numpy(dtype=np.float64)
            mids = (
                cell_spots["start"].to_numpy(dtype=np.float64)
                + cell_spots["end"].to_numpy(dtype=np.float64)
            ) / 2
            # All unordered pairs i < j, in the same order a nested loop
            # would visit them.
            first, second = np.triu_indices(n, k=1)
            sep = np.abs(mids[second] - mids[first])
            dist = np.linalg.norm(xyz[first] - xyz[second], axis=1)
            # Drop zero-separation pairs and pairs with missing coordinates.
            keep = (sep > 0) & np.isfinite(dist)
            dist_chunks.append(dist[keep])
            sep_chunks.append(sep[keep])

        n_pairs = sum(chunk.size for chunk in dist_chunks)
        if n_pairs < min_pairs:
            out[str(chrom)] = fallback
            continue

        tau, alpha = fit_scale_factor(
            np.concatenate(dist_chunks),
            np.concatenate(sep_chunks),
            l_p_bp=params.persistence_length_bp,
            fix_exponent=params.scaling_exponent_fixed,
        )
        out[str(chrom)] = (tau, alpha)
    return out
# ----------------------------------------------------------------------
# Per (cell, chrom) tracing
# ----------------------------------------------------------------------
def _trace_one(
    sub_df: pd.DataFrame,
    chrom: str,
    tau_nm_per_bp: float,
    params: SpotAlignerParams,
) -> Tuple[List[dict], List[int], List[Tuple[int, int]]]:
    """Trace one cell's one chromosome.

    Parameters
    ----------
    sub_df
        Detections of a single (cell, chromosome); positional row order is
        the index space used by the returned fibers.
    chrom
        Chromosome label.  Currently unused inside this function; kept for
        the call signature.
    tau_nm_per_bp
        Genomic-to-spatial scale forwarded to ``iterative_karyotype``.
    params
        Aligner parameters forwarded to the graph routines.

    Returns
    -------
    fibers : list of dict
        One per called fiber, with keys ``spots`` (indices into
        ``sub_df``), ``path_weight``, ``mean_edge``.
    sister_group : list of int
        Group index per fiber (same index = sisters).  When sister
        detection is disabled or no fiber was called, every fiber gets its
        own group.
    bin_ids : list of (start, end)
        Unique loci of ``sub_df`` ordered by ``(start, end)``.
    """
    # Integer locus key per spot, computed column-wise instead of row by row.
    starts = sub_df["start"].astype(np.int64).to_numpy()
    ends = sub_df["end"].astype(np.int64).to_numpy()
    spot_keys = list(zip(starts.tolist(), ends.tolist()))

    # Sorting on (start, end) — not start alone — keeps the bin midpoints
    # non-decreasing even when several loci share a start coordinate.
    bin_ids: List[Tuple[int, int]] = sorted(set(spot_keys))
    bin_bp = np.array([0.5 * (s + e) for s, e in bin_ids], dtype=np.float64)
    locus_to_idx = {locus: i for i, locus in enumerate(bin_ids)}

    spot_xyz = sub_df[["x", "y", "z"]].to_numpy(dtype=np.float64)
    spot_locus = np.fromiter(
        (locus_to_idx[key] for key in spot_keys),
        dtype=np.int64,
        count=len(spot_keys),
    )

    # NOTE(review): with sigma columns the array is (n_spots, 3); the
    # fallback is (n_spots,).  Confirm ``iterative_karyotype`` accepts both.
    if {"sigma_x", "sigma_y", "sigma_z"}.issubset(sub_df.columns):
        sigma = sub_df[["sigma_x", "sigma_y", "sigma_z"]].to_numpy(
            dtype=np.float64
        )
    else:
        sigma = np.full(
            spot_xyz.shape[0], params.default_sigma_nm, dtype=np.float64
        )

    fibers = iterative_karyotype(
        spot_xyz, spot_locus, bin_bp, sigma,
        l_p_bp=params.persistence_length_bp,
        tau_nm_per_bp=tau_nm_per_bp,
        max_skip=params.max_skip,
        gap_penalty=params.gap_penalty,
        max_fibers=params.max_fibers_per_chrom,
        min_spots_per_fiber=params.min_spots_per_fiber,
        mean_edge_weight_cutoff=params.mean_edge_weight_cutoff,
    )

    if params.detect_sisters and fibers:
        groups = detect_sister_pairs(
            fibers, spot_xyz, spot_locus,
            pair_radius_nm=params.sister_pair_radius_nm,
        )
    else:
        # No sister detection: each fiber is its own group.
        groups = list(range(len(fibers)))
    return fibers, groups, bin_ids
# ----------------------------------------------------------------------
# Main entry point
# ----------------------------------------------------------------------
[docs]
def align_spots(
    spots_df: pd.DataFrame,
    params: Optional[SpotAlignerParams] = None,
    tau_by_chrom: Optional[Dict[str, float]] = None,
) -> "ChromData":
    """Assign detected FISH spots to chromatin fibers.

    See module docstring for the input-DataFrame contract.

    Parameters
    ----------
    spots_df
        Long-format detections, one row per spot.
    params
        Aligner parameters; defaults to ``SpotAlignerParams()``.
    tau_by_chrom
        Optional per-chromosome τ (nm/bp).  Keys are matched against
        ``str(chrom)``; when given, it takes precedence over
        ``params.tau_nm_per_bp`` and over fitting.  Chromosomes without an
        entry are traced with a fallback τ of 0.01.

    Returns
    -------
    ChromData
        ``spots``/``coords`` hold the kept detections, ``traces`` one row
        per fiber, ``results['ploidy']`` one row per (cell, chrom) group,
        and ``uns`` the polymer fit and a subset of the parameters.

    Raises
    ------
    ValueError
        If required columns are missing, or if no fiber was called.
    """
    # Validate before doing anything expensive (including project imports).
    required = {"cell_id", "chrom", "start", "end", "x", "y", "z"}
    missing = required - set(spots_df.columns)
    if missing:
        raise ValueError(
            f"spots_df missing required columns: {sorted(missing)}"
        )

    if params is None:
        params = SpotAlignerParams()

    from uchrom.core import ChromData

    # ---------------------------------------------------------
    # τ fitting
    # ---------------------------------------------------------
    tau_by_chrom, uns_polymer = _resolve_tau(spots_df, params, tau_by_chrom)
    if params.verbose:
        rounded = {c: round(t, 5) for c, t in tau_by_chrom.items()}
        print(f"[jie] τ per chrom: {rounded}")

    # ---------------------------------------------------------
    # Per (cell, chrom) trace assignment
    # ---------------------------------------------------------
    fallback_tau = 0.01
    has_spot_id = "spot_id" in spots_df.columns

    spots_out: List[dict] = []
    coords_out: List[np.ndarray] = []
    traces_out: List[dict] = []
    ploidy_out: List[dict] = []
    next_global_tid = 0
    next_global_sg = 0

    # One groupby object serves both the progress total and the loop, so
    # the two always agree (and nothing is materialised just to count).
    grouped = spots_df.groupby(["cell_id", "chrom"], observed=True)
    total_groups = grouped.ngroups

    for gi, ((cell_id, chrom), sub_df) in enumerate(grouped):
        sub_df = sub_df.reset_index(drop=True)
        chrom_key = str(chrom)
        tau = tau_by_chrom.get(chrom_key, fallback_tau)
        fibers, sis_groups, _ = _trace_one(sub_df, chrom_key, tau, params)

        # Re-number sister groups into the global namespace.
        sg_map: Dict[int, int] = {}
        for sg in sis_groups:
            if sg not in sg_map:
                sg_map[sg] = next_global_sg
                next_global_sg += 1

        # Pull the columns out once per group; fibers index into them.
        starts = sub_df["start"].to_numpy()
        ends = sub_df["end"].to_numpy()
        xyz = sub_df[["x", "y", "z"]].to_numpy(dtype=np.float64)
        spot_ids = sub_df["spot_id"].tolist() if has_spot_id else None

        for fi, fiber in enumerate(fibers):
            tid = next_global_tid
            next_global_tid += 1
            sg_id = sg_map[sis_groups[fi]]

            for spot_idx in fiber["spots"]:
                spot_entry = {
                    "chrom": chrom_key,
                    "start": int(starts[spot_idx]),
                    "end": int(ends[spot_idx]),
                    "trace_id": tid,
                    "cell_id": cell_id,
                }
                if spot_ids is not None:
                    spot_entry["spot_id"] = spot_ids[spot_idx]
                spots_out.append(spot_entry)
                coords_out.append(xyz[spot_idx])

            traces_out.append({
                "trace_id": tid,
                "cell_id": cell_id,
                "chrom": chrom_key,
                "n_spots": len(fiber["spots"]),
                "path_weight": float(fiber["path_weight"]),
                "mean_edge": float(fiber["mean_edge"]),
                "sister_group": sg_id,
            })

        ploidy_out.append({
            "cell_id": cell_id,
            "chrom": chrom_key,
            "n_fibers": len(fibers),
            "n_sister_groups": len(sg_map),
            "tau_nm_per_bp": float(tau),
        })
        if params.verbose and (gi + 1) % 20 == 0:
            print(f"[jie] traced {gi + 1}/{total_groups} (cell, chrom) groups")

    if not coords_out:
        raise ValueError(
            "No fibers called — check inputs, min_spots_per_fiber, "
            "and gap_penalty calibration."
        )

    coords = np.vstack(coords_out)
    spots = pd.DataFrame(spots_out)
    traces = pd.DataFrame(traces_out)
    ploidy = pd.DataFrame(ploidy_out)
    uns = {
        "jie_polymer": uns_polymer,
        "jie_params": {
            "persistence_length_bp": params.persistence_length_bp,
            "max_skip": params.max_skip,
            "gap_penalty": params.gap_penalty,
            "default_sigma_nm": params.default_sigma_nm,
            "min_spots_per_fiber": params.min_spots_per_fiber,
        },
    }
    cd = ChromData(
        coords=coords,
        spots=spots,
        traces=traces,
        uns=uns,
        results={"ploidy": ploidy},
    )
    if params.verbose:
        print(f"[jie] → {len(traces)} fibers across "
              f"{spots_df['cell_id'].nunique()} cells")
    return cd


def _resolve_tau(
    spots_df: pd.DataFrame,
    params: SpotAlignerParams,
    tau_by_chrom: Optional[Dict[str, float]],
) -> Tuple[Dict[str, float], Dict[str, dict]]:
    """Pick τ per chromosome and build the matching ``uns`` record.

    Precedence: explicit ``tau_by_chrom``, then ``params.tau_nm_per_bp``
    applied to every chromosome in ``spots_df``, then a per-chromosome fit.
    The scaling exponent is recorded as 0.5 unless it came from the fit.
    """
    if tau_by_chrom is not None:
        # Normalise keys: lookups during tracing use ``str(chrom)``.
        taus = {str(c): float(t) for c, t in tau_by_chrom.items()}
        exponents = dict.fromkeys(taus, 0.5)
    elif params.tau_nm_per_bp is not None:
        tau = float(params.tau_nm_per_bp)
        taus = {str(c): tau for c in spots_df["chrom"].unique()}
        exponents = dict.fromkeys(taus, 0.5)
    else:
        if params.verbose:
            print("[jie] fitting τ per chromosome...")
        fits = _fit_tau_global(spots_df, params)
        taus = {c: t for c, (t, _) in fits.items()}
        exponents = {c: a for c, (_, a) in fits.items()}

    l_p_bp = float(params.persistence_length_bp)
    uns_polymer = {
        c: {
            "tau_nm_per_bp": float(t),
            "scaling_exponent": float(exponents[c]),
            "l_p_bp": l_p_bp,
        }
        for c, t in taus.items()
    }
    return taus, uns_polymer
# Public API of this module: the parameter dataclass and the entry point.
__all__ = [
    "SpotAlignerParams",
    "align_spots",
]