# Source code for uchrom.fea.arc

"""ArcFISH-style axis-wise preprocessing for chromatin tracing data.

References
----------
Yu H. *et al.* *Accurate and robust 3D genome feature discovery from
multiplexed DNA FISH*, bioRxiv 2025.11.26.690837v1.

Independent implementation in :mod:`uchrom` — not derived from the GPL-3.0
ArcFISH source.

Pipeline (per chromosome)
-------------------------
1. ``axis_variance_cube``
   Builds ``(3, n_bins, n_bins)`` per-axis pairwise variance + count cubes
   from ``ChromData`` spots.  Each trace contributes, for each axis, its
   matrix of pairwise coordinate differences between bins; aggregation
   across traces is NaN-aware.
2. ``filter_normalize``
   Two-pass LOWESS stratification on log(1D genomic distance):
     - first pass: flag individual trace observations whose absolute
       deviation from the per-pair median exceeds ``k_sigma`` × the
       stratified std as outliers and NaN them;
     - second pass: refit LOWESS on the cleaned variances to give each
       entry a genome-distance-matched expectation, then normalise.
3. ``axis_weight``
   Returns a 3-vector of weights (sum 1) inversely proportional to the
   per-axis trace-variance median — the exact weighting used by the
   ACAT combination step in the loop / tad / comp callers.

All tensor-heavy computation runs on a user-selected torch device
(``'auto' | 'cpu' | 'cuda' | 'mps'``).  LOWESS stays on CPU via
``statsmodels`` because it's a non-vectorised kernel smoother whose
input size is ``O(n_bins²)`` (typically ≤ 10 k).
"""

from __future__ import annotations

from typing import Optional

import numpy as np
import pandas as pd

from uchrom.utils import get_device, lowess_log_log
from uchrom.utils.stats import default_float_dtype


# ----------------------------------------------------------------------
# Coordinate cube: (n_traces, n_bins, 3) with NaN padding
# ----------------------------------------------------------------------

def _coord_cube_for_chrom(cd, chrom):
    """Assemble a NaN-padded coordinate cube for a single chromosome.

    Parameters
    ----------
    cd
        Object exposing ``spots`` (a DataFrame with ``chrom``, ``start``,
        ``end`` and ``trace_id`` columns) and ``coords`` (an array whose
        first three columns are x, y, z, row-aligned with ``spots``).
    chrom
        Chromosome label; matched against ``spots['chrom']`` as strings.

    Returns
    -------
    tuple
        ``(coords_cube, bin_ids, trace_ids)``.  ``coords_cube`` has shape
        ``(n_traces, n_bins, 3)`` and ``coords_cube[t, i, :]`` is the
        (x, y, z) of trace ``t`` at bin ``i`` (NaN if the trace missed that
        bin).  ``bin_ids`` lists integer ``(start, end)`` pairs ordered by
        ``start`` (stable for ties); ``trace_ids`` lists trace identifiers
        in order of first appearance.

    Raises
    ------
    ValueError
        If no spot belongs to ``chrom``.
    """
    axis_names = ("x", "y", "z")

    frame = cd.spots.copy()
    for axis, name in enumerate(axis_names):
        frame[name] = cd.coords[:, axis]
    frame = frame[frame["chrom"].astype(str) == str(chrom)]
    if frame.empty:
        raise ValueError(f"No spots found for chromosome {chrom}")

    # Integer boundaries so (start, end) keys never hash/compare as floats.
    frame = frame.assign(
        start=frame["start"].astype(np.int64),
        end=frame["end"].astype(np.int64),
    )

    ordered_bins = (
        frame[["start", "end"]]
        .drop_duplicates()
        .sort_values("start", kind="mergesort")
    )
    bin_keys = list(
        zip(ordered_bins["start"].tolist(), ordered_bins["end"].tolist())
    )
    slot_of_bin = {key: slot for slot, key in enumerate(bin_keys)}

    trace_ids = list(frame["trace_id"].unique())
    slot_of_trace = {tid: slot for slot, tid in enumerate(trace_ids)}

    n_spots = len(frame)
    rows = np.fromiter(
        (slot_of_trace[tid] for tid in frame["trace_id"]),
        dtype=np.int64,
        count=n_spots,
    )
    cols = np.fromiter(
        (
            slot_of_bin[key]
            for key in zip(frame["start"].tolist(), frame["end"].tolist())
        ),
        dtype=np.int64,
        count=n_spots,
    )

    cube = np.full((len(trace_ids), len(bin_keys), 3), np.nan, dtype=np.float64)
    for axis, name in enumerate(axis_names):
        cube[rows, cols, axis] = frame[name].to_numpy()
    return cube, bin_keys, trace_ids


# ----------------------------------------------------------------------
# Step 1: axis variance cube (3, n_bins, n_bins)
# ----------------------------------------------------------------------

def _pairwise_diff_tensor(cd, chrom: str, device: str = "auto"):
    """Return the full ``(3, n_traces, n_bins, n_bins)`` pairwise difference.

    ``diff[c, t, i, j] = coord_c[t, i] - coord_c[t, j]`` (bin ``i`` minus
    bin ``j``), NaN wherever trace ``t`` is missing bin ``i`` or bin ``j``.
    The tensor is therefore antisymmetric in ``(i, j)`` and its diagonal is
    0 for observed bins; any mean over traces flips sign under ``i <-> j``
    while variances do not.

    The tensor holds ``3 * n_traces * n_bins**2`` elements and is kept on
    the device selected by ``get_device(device)`` with dtype
    ``default_float_dtype`` of that device; the caller is responsible for
    moving it to CPU if needed.

    Returns
    -------
    tuple
        ``(diff, bin_ids, trace_ids)`` where ``bin_ids`` and ``trace_ids``
        come from :func:`_coord_cube_for_chrom`.
    """
    import torch

    cube, bin_ids, trace_ids = _coord_cube_for_chrom(cd, chrom)
    dev = get_device(device)
    dtype = default_float_dtype(dev)

    c = torch.as_tensor(cube, dtype=dtype, device=dev)   # (T, B, 3)
    # (T, B, 1, 3) - (T, 1, B, 3) -> [t, i, j, :] = coord[t, i] - coord[t, j]
    diff = c.unsqueeze(2) - c.unsqueeze(1)               # (T, B, B, 3)
    diff = diff.permute(3, 0, 1, 2).contiguous()         # (3, T, B, B)
    return diff, bin_ids, trace_ids


def axis_variance_cube(
    cd,
    chrom: str,
    device: str = "auto",
) -> dict:
    """Per-axis pairwise variance, mean and sample-count cubes for one chromosome.

    Parameters
    ----------
    cd
        Chromatin-tracing container forwarded to the coordinate-cube
        builder (needs ``spots`` and ``coords``).
    chrom
        Chromosome to process.
    device
        Torch device spec (``'auto' | 'cpu' | 'cuda' | 'mps'``).

    Returns
    -------
    dict
        ``var`` and ``mean`` (float64) plus ``count`` (int64), each of shape
        ``(3, B, B)``: for every axis and bin pair, the number of traces with
        a non-NaN difference, the mean of those differences and their
        population variance (divisor ``count``).  ``var``/``mean`` are NaN
        where ``count == 0``.  ``diff`` is the full ``(3, T, B, B)``
        pairwise-difference tensor, left on the compute device for
        :func:`filter_normalize`.  ``bin_ids``, ``n_traces`` and ``chrom``
        carry metadata.
    """
    import torch

    diff, bin_ids, trace_ids = _pairwise_diff_tensor(cd, chrom, device)

    observed = torch.isnan(diff).logical_not()          # (3, T, B, B)
    zeros = torch.zeros_like(diff)
    n_obs = observed.sum(dim=1).to(torch.int64)         # (3, B, B)
    denom = n_obs.clamp(min=1).to(diff.dtype)           # avoid 0-division

    # NaN-aware mean over the trace dimension.
    mean = torch.where(observed, diff, zeros).sum(dim=1) / denom
    # Population variance about that mean, ignoring missing observations.
    sq_resid = torch.where(observed, (diff - mean.unsqueeze(1)) ** 2, zeros)
    var = sq_resid.sum(dim=1) / denom

    unobserved = n_obs == 0
    var = var.masked_fill(unobserved, float("nan"))
    mean = mean.masked_fill(unobserved, float("nan"))

    def _host(tensor, np_dtype):
        return tensor.detach().cpu().numpy().astype(np_dtype)

    return {
        "var": _host(var, np.float64),
        "count": _host(n_obs, np.int64),
        "mean": _host(mean, np.float64),
        "diff": diff,  # stays on the compute device
        "bin_ids": bin_ids,
        "n_traces": len(trace_ids),
        "chrom": chrom,
    }
# ----------------------------------------------------------------------
# Step 2: filter_normalize (2-pass LOWESS)
# ----------------------------------------------------------------------


def _bin_midpoints(bin_ids):
    """Midpoint ``(start + end) / 2`` of every ``(start, end)`` bin, as float64."""
    return np.fromiter(
        ((lo + hi) * 0.5 for lo, hi in bin_ids),
        dtype=np.float64,
    )


def _genomic_distance_matrix(bin_ids):
    """Symmetric ``(n_bins, n_bins)`` matrix of absolute midpoint separations."""
    centres = _bin_midpoints(bin_ids)
    return np.abs(np.subtract.outer(centres, centres))
def _symmetric_from_upper(values, iu, n_bins):
    """Mirror strict-upper-triangle ``values`` into a symmetric matrix with a NaN diagonal."""
    mat = np.full((n_bins, n_bins), np.nan)
    mat[iu] = values
    mat.T[iu] = values  # .T is a view, so this fills the lower triangle
    return mat


def _lowess_upper(d_flat, v_flat, frac):
    """Run ``lowess_log_log`` on flattened bin pairs; no pairs gives an empty fit."""
    if d_flat.size == 0:
        # Fewer than two bins: there are no off-diagonal pairs to smooth.
        return np.empty(0, dtype=np.float64)
    return lowess_log_log(d_flat, v_flat, frac=frac)


def filter_normalize(
    cube: dict,
    k_sigma: float = 4.0,
    frac: float = 0.1,
) -> dict:
    """ArcFISH-style per-trace LOWESS filter + normalise.

    Operates on the full ``(3, n_traces, n_bins, n_bins)`` pairwise-diff
    tensor stored under ``cube['diff']`` (kept on the compute device).
    Two passes:

    1. ``raw_var[c, i, j] = nanmedian_t((diff - nanmedian_t diff)**2)``;
       values below ``1e-30`` become NaN.  For each axis the upper-triangle
       ``raw_var`` is smoothed against the genomic distance of the bin pair
       with ``lowess_log_log`` (log(raw_var) ~ log(d1d)), and
       ``strata_std = sqrt(max(fit, 1e-30))``.  Individual trace
       observations with ``|diff - median_t(diff)| > k_sigma * strata_std``
       are replaced by NaN in a new tensor; ``cube['diff']`` itself is not
       modified.
    2. On the filtered tensor, per-pair ``count`` (non-NaN traces), mean
       and population variance (``filtered_var``, divisor ``count``) are
       recomputed.  Upper-triangle variances ``<= 1e-30`` are treated as
       missing for a second LOWESS fit giving ``expected``;
       ``norm_var = filtered_var / expected`` where ``expected > 0`` (NaN
       otherwise), with its diagonal set to 0.

    Parameters
    ----------
    cube
        Output of :func:`axis_variance_cube` (must still contain ``diff``
        and ``bin_ids``).
    k_sigma
        Outlier threshold in units of the stratified std.
    frac
        LOWESS smoothing fraction passed to ``lowess_log_log``.

    Returns
    -------
    dict
        A shallow copy of ``cube`` with ``var`` and ``count`` replaced by
        their post-filter values, new keys ``norm_var``, ``expected``,
        ``raw_var`` and ``genomic_distance`` (all NumPy on CPU), and
        ``diff`` removed.  NOTE: ``mean`` is passed through unchanged from
        ``cube``, i.e. it reflects the *unfiltered* data.

    Raises
    ------
    KeyError
        If ``cube`` has no ``diff`` entry (e.g. it is already the output
        of this function).
    """
    import torch

    nan = float("nan")
    diff: "torch.Tensor" = cube["diff"]  # (3, T, B, B) on device
    n_axes, _, n_bins, _ = diff.shape

    d1d = _genomic_distance_matrix(cube["bin_ids"])  # (B, B) numpy
    iu = np.triu_indices(n_bins, k=1)
    d_flat = d1d[iu]

    # -------- Pass 1: robust spread, LOWESS, per-trace outlier mask --------
    # Deviation of every trace from the per-pair median over traces.
    dev = diff - _nanmedian_along_dim(diff, 1, keepdim=True)     # (3, T, B, B)
    raw_var = _nanmedian_along_dim(dev ** 2, 1, keepdim=False)   # (3, B, B)
    # Near-zero spreads would blow up in log space; treat them as missing.
    raw_var = torch.where(raw_var < 1e-30, torch.full_like(raw_var, nan), raw_var)
    raw_var_np = raw_var.detach().cpu().numpy().astype(np.float64)

    strata_std = np.full_like(raw_var_np, np.nan)
    for a in range(n_axes):
        fit = _lowess_upper(d_flat, raw_var_np[a][iu], frac)
        strata_std[a] = _symmetric_from_upper(
            np.sqrt(np.clip(fit, 1e-30, None)), iu, n_bins
        )

    strata_std_t = torch.as_tensor(
        strata_std, dtype=diff.dtype, device=diff.device
    )  # (3, B, B)
    threshold = (k_sigma * strata_std_t).unsqueeze(1)  # (3, 1, B, B)
    # Reuse `dev` instead of recomputing the (sort-based) nanmedian.
    # NaN deviations / NaN thresholds compare False, so they are never outliers.
    outlier = dev.abs() > threshold
    del dev  # release the large 4D intermediate early
    # Out-of-place: builds a new tensor, the caller's cube['diff'] is untouched.
    diff = torch.where(outlier, torch.full_like(diff, nan), diff)
    del outlier

    # -------- Pass 2: filtered_var via nanmean squared deviation --------
    valid = ~torch.isnan(diff)
    count = valid.sum(dim=1).to(torch.int64)  # (3, B, B)
    safe_count = count.clamp(min=1).to(diff.dtype)
    zeros = torch.zeros_like(diff)
    mean = torch.where(valid, diff, zeros).sum(dim=1) / safe_count  # (3, B, B)
    dev_sq = torch.where(valid, (diff - mean.unsqueeze(1)) ** 2, zeros)
    filtered_var = dev_sq.sum(dim=1) / safe_count  # (3, B, B)
    filtered_var = torch.where(
        count == 0, torch.full_like(filtered_var, nan), filtered_var
    )

    fvar_np = filtered_var.detach().cpu().numpy().astype(np.float64)
    count_np = count.detach().cpu().numpy().astype(np.int64)

    expected = np.full_like(fvar_np, np.nan)
    norm_var = np.full_like(fvar_np, np.nan)
    for a in range(n_axes):
        v_flat = fvar_np[a][iu]
        v_flat_ok = np.where(v_flat > 1e-30, v_flat, np.nan)
        exp_axis = _symmetric_from_upper(
            _lowess_upper(d_flat, v_flat_ok, frac), iu, n_bins
        )
        expected[a] = exp_axis
        with np.errstate(invalid="ignore", divide="ignore"):
            norm_axis = np.where(exp_axis > 0, fvar_np[a] / exp_axis, np.nan)
        # Zero the diagonal for cleanliness (self-pairs carry no signal).
        np.fill_diagonal(norm_axis, 0.0)
        norm_var[a] = norm_axis

    out = dict(cube)
    out["var"] = fvar_np
    out["count"] = count_np
    out["raw_var"] = raw_var_np
    out["expected"] = expected
    out["norm_var"] = norm_var
    out["genomic_distance"] = d1d
    # Drop the large 4D tensor from the result — no longer needed downstream.
    out.pop("diff", None)
    return out
def _lower_nanmedian_np(arr, axis: int, keepdims: bool = False):
    """NumPy NaN-aware *lower* median along ``axis`` (``torch.nanmedian`` convention)."""
    ordered = np.sort(arr, axis=axis)  # NaNs are sorted to the end
    n_valid = np.sum(~np.isnan(arr), axis=axis, keepdims=True)
    # Index of the lower middle non-NaN element; all-NaN slices use index 0,
    # which holds NaN after sorting, so they yield NaN without warnings.
    idx = np.maximum(n_valid - 1, 0) // 2
    med = np.take_along_axis(ordered, idx, axis=axis)
    return med if keepdims else np.squeeze(med, axis=axis)


def _nanmedian_along_dim(t, dim: int, keepdim: bool = False):
    """NaN-aware median along a single dim for a torch tensor.

    Uses ``torch.nanmedian`` when the backend supports it and otherwise
    falls back to a NumPy computation on CPU (e.g. on MPS, where nanmedian
    kernels are spotty).  Both paths return the *lower* of the two middle
    values when the number of non-NaN entries is even, and NaN for all-NaN
    slices, so the result does not depend on which path ran.

    Parameters
    ----------
    t
        Input tensor.
    dim
        Dimension to reduce.
    keepdim
        Keep ``dim`` as a size-1 dimension.

    Returns
    -------
    torch.Tensor
        Median values with ``t``'s dtype, on ``t``'s device.
    """
    import torch

    try:
        return torch.nanmedian(t, dim=dim, keepdim=keepdim).values
    except RuntimeError:  # also covers NotImplementedError (missing kernel)
        arr = t.detach().cpu().numpy()
        med = _lower_nanmedian_np(arr, dim, keepdim)
        return torch.as_tensor(med, dtype=t.dtype, device=t.device)


# ----------------------------------------------------------------------
# Step 3: axis weights for ACAT
# ----------------------------------------------------------------------
def axis_weight(
    cd,
    chrom: Optional[str] = None,
    device: str = "auto",
) -> np.ndarray:
    """Compute per-axis weights ``w ∝ 1 / median(trace_variance)``.

    For each axis every trace is centred at its own mean over observed
    (non-NaN) bins, its variance is the mean squared deviation over those
    bins, and the NaN-aware median of those variances across traces is
    taken.  The weight is ``1 / max(median, 1e-30)``, normalised so the
    three weights sum to 1.  Consistent with ArcFISH's ``axis_weight``
    routine.

    Parameters
    ----------
    cd
        Chromatin-tracing container with ``spots`` (including a ``chrom``
        column) and ``coords``.
    chrom
        Chromosome to use.  If ``None``, weights are computed for every
        chromosome separately (traces are per-chromosome) and their
        unweighted mean is returned.
    device
        Torch device spec (``'auto' | 'cpu' | 'cuda' | 'mps'``).

    Returns
    -------
    np.ndarray
        Float64 array of shape ``(3,)``.

    Raises
    ------
    ValueError
        If no spot matches ``chrom`` (or ``spots`` is empty when
        ``chrom`` is ``None``).
    """
    import torch

    chrom_labels = cd.spots["chrom"]
    if chrom is None:
        has_spots = len(chrom_labels) > 0
    else:
        has_spots = bool((chrom_labels.astype(str) == str(chrom)).any())
    if not has_spots:
        raise ValueError("No spots found for the given chromosome filter.")

    if chrom is None:
        # Traces never span chromosomes, so pooling them would be wrong:
        # compute weights per chromosome and average them instead.
        per_chrom = [
            axis_weight(cd, chrom=str(c), device=device)
            for c in chrom_labels.unique()
        ]
        return np.mean(np.stack(per_chrom, axis=0), axis=0)

    cube, _, _ = _coord_cube_for_chrom(cd, chrom)  # (T, B, 3)
    dev = get_device(device)
    dtype = default_float_dtype(dev)

    c = torch.as_tensor(cube, dtype=dtype, device=dev)  # (T, B, 3)
    valid = ~torch.isnan(c)
    count = valid.sum(dim=1, keepdim=True).clamp(min=1).to(dtype)  # (T, 1, 3)
    zeros = torch.zeros_like(c)
    mean = torch.where(valid, c, zeros).sum(dim=1, keepdim=True) / count  # (T, 1, 3)
    centred = torch.where(valid, c - mean, zeros)
    # Variance per trace per axis: mean over valid bins (divisor = count).
    # NOTE(review): traces observed at a single bin contribute variance 0 and
    # are included in the median below — confirm that is intended.
    trace_var = (centred ** 2).sum(dim=1) / count.squeeze(1)  # (T, 3)

    # Median across traces; the shared helper handles backends lacking a
    # nanmedian kernel.
    per_axis_median = _nanmedian_along_dim(trace_var, 0)  # (3,)
    w = 1.0 / per_axis_median.clamp(min=1e-30)
    w = w / w.sum()
    return w.detach().cpu().numpy().astype(np.float64)
# Public API of this module.
__all__ = [
    "axis_variance_cube",
    "filter_normalize",
    "axis_weight",
]