Source code for uchrom.strc.loop.axiswise_f

"""ArcFISH-style axis-wise F-test loop caller (``AxisWiseF`` + ``LoopCaller``).

Independent implementation of the algorithm described in Yu et al. 2025
(ArcFISH); not derived from the GPL-3.0 ArcFISH source.

Pipeline
--------
1. Per-axis variance + counts → :func:`uchrom.fea.arc.axis_variance_cube`.
2. LOWESS-normalised variance → :func:`uchrom.fea.arc.filter_normalize`.
3. Axis weights for ACAT → :func:`uchrom.fea.arc.axis_weight`.
4. For each candidate bin pair ``(i, j)`` within ``[cut_lo, cut_up]`` 1D
   genomic distance:
     a. Compute a local background mean of normalised variance over a
        ring ``[inner_cut, outer_cut]`` around ``(i, j)``, *excluding* its
        own row/column.
     b. Per-axis F-statistic ``F_c = var_c[i,j] / denom_c``; left-tailed
        F-CDF gives ``p_c`` (loops have below-average variance).
     c. ACAT combine per-axis p-values with axis weights → per-pair p.
5. BH-FDR adjust; accept candidates with ``fdr < fdr_cutoff``.
6. Cluster accepted candidates within ``gap`` 1D distance; pick the
   lowest-p entry in each cluster as the "summit".
7. Filter summits by raw ``pval_cutoff``, minimum cluster size, and
   pseudo-contact frequency (> 1/2 for singleton clusters, > 1/3 otherwise).

Result columns
--------------
``chrom1, start1, end1, chrom2, start2, end2, score (-log10 p), pval, fdr,
 cluster_size, contact_freq, summit_i, summit_j``
"""

from __future__ import annotations

from dataclasses import dataclass, field
from typing import List, Optional, Tuple

import numpy as np
import pandas as pd

from uchrom.fea.arc import axis_variance_cube, filter_normalize, axis_weight
from uchrom.fea.arc import _pairwise_diff_tensor
from uchrom.utils.stats import cauchy_combination


# ----------------------------------------------------------------------
# Parameters
# ----------------------------------------------------------------------

@dataclass
class LoopCallerParams:
    """Runtime parameters for :func:`call_loops_axiswise_f`.

    Defaults mirror ArcFISH's ``LoopCaller``.

    Distances (``cut_lo``, ``cut_up``, ``inner_cut``, ``outer_cut``, ``gap``)
    are compared against 1D genomic distances and bin widths by the caller;
    presumably they are expressed in base pairs -- confirm against the
    upstream bin definitions.

    Raises
    ------
    ValueError
        If ``cut_lo`` is not strictly below ``cut_up`` or ``inner_cut`` is
        not strictly below ``outer_cut`` (this includes NaN values).
    """

    cut_lo: float = 1e5          # min 1D size of a candidate loop
    cut_up: float = 1e6          # max 1D size of a candidate loop
    inner_cut: float = 2.5e4     # inner ring radius for local background
    outer_cut: float = 5e4       # outer ring radius for local background
    fdr_cutoff: float = 0.1      # candidates are accepted when fdr < fdr_cutoff
    pval_cutoff: float = 1e-5    # summits are kept when pval < pval_cutoff
    gap: float = 5e4             # 1D gap for clustering summit neighbours
    k_sigma: float = 4.0         # outlier k for LOWESS-based filter
    frac: float = 0.1            # LOWESS span
    min_cluster_size: int = 1    # clusters smaller than this are dropped

    def __post_init__(self):
        # Explicit raises rather than ``assert``: assertions are removed when
        # Python runs with -O, which would silently disable validation.
        # ``not (a < b)`` (instead of ``a >= b``) also rejects NaN inputs.
        if not (self.cut_lo < self.cut_up):
            raise ValueError(
                f"cut_lo ({self.cut_lo}) must be strictly less than "
                f"cut_up ({self.cut_up})"
            )
        if not (self.inner_cut < self.outer_cut):
            raise ValueError(
                f"inner_cut ({self.inner_cut}) must be strictly less than "
                f"outer_cut ({self.outer_cut})"
            )
# ----------------------------------------------------------------------
# Helpers
# ----------------------------------------------------------------------

def _bh_fdr(pvals: np.ndarray) -> np.ndarray:
    """Benjamini–Hochberg FDR on an array of p-values.

    Non-finite entries (NaN / inf) are excluded from the correction and come
    back as NaN. Returns q-values with the same shape as the input.
    """
    p = np.asarray(pvals, dtype=np.float64)
    out = np.full_like(p, np.nan)
    finite = np.isfinite(p)
    if not finite.any():
        return out

    pf = p[finite]
    n = pf.size
    order = np.argsort(pf)
    # Raw BH adjustment: p_(k) * n / k for the k-th smallest p-value.
    q = pf[order] * n / (np.arange(n) + 1.0)
    # Enforce monotonicity (reverse cumulative min), then bound to [0, 1].
    q = np.minimum.accumulate(q[::-1])[::-1]
    q = np.clip(q, 0.0, 1.0)

    # Scatter the sorted q-values back to the original positions.
    out_finite = np.empty_like(pf)
    out_finite[order] = q
    out[finite] = out_finite
    return out


def _ring_mask(B: int, inner_bins: int, outer_bins: int):
    """Placeholder for a boolean ``(B, B, B, B)`` per-pair ring mask.

    ``mask[i, j, a, b]`` would be True if ``(a, b)`` lies in the ring around
    ``(i, j)`` excluding row ``i`` and column ``j``. Such a mask can be huge,
    so it is never materialised: :func:`_background_ring` computes each
    pair's background on the fly instead.

    Raises
    ------
    NotImplementedError
        Always.
    """
    raise NotImplementedError("Handled inline in _background_ring.")


def _background_ring(
    var_axis: np.ndarray,
    count_axis: np.ndarray,
    i: int,
    j: int,
    inner_bins: int,
    outer_bins: int,
):
    """Return ``(denom, total_count)`` for the ring around ``(i, j)``.

    The ring is the square window of half-width ``outer_bins`` centred on
    ``(i, j)`` (clipped to the matrix), minus the inner square of half-width
    ``inner_bins`` (inclusive), minus all of row ``i`` and column ``j``.

    ``denom`` is the count-weighted mean of ``var_axis`` over ring entries
    that are finite and have a positive count. ``total_count`` is the sum of
    those counts -- the caller uses it as ``df2`` for the F-test. When no
    entry qualifies the result is ``(nan, 0.0)``.

    Assumes ``var_axis`` and ``count_axis`` are square ``(B, B)`` arrays and
    ``0 <= i, j < B``.
    """
    B = var_axis.shape[0]
    lo_i, hi_i = max(0, i - outer_bins), min(B, i + outer_bins + 1)
    lo_j, hi_j = max(0, j - outer_bins), min(B, j + outer_bins + 1)

    # Inner square, clamped to the outer window so that the window-relative
    # slice below stays valid even if inner_bins exceeds outer_bins.
    in_lo_i = max(lo_i, i - inner_bins)
    in_hi_i = min(hi_i, i + inner_bins + 1)
    in_lo_j = max(lo_j, j - inner_bins)
    in_hi_j = min(hi_j, j + inner_bins + 1)

    # Mask only the outer window rather than the whole (B, B) matrix: this
    # function runs once per candidate pair and axis, so a full-size mask
    # per call would cost O(B^2) each time. Element order (row-major inside
    # the window) is the same as with a full mask.
    keep = np.ones((hi_i - lo_i, hi_j - lo_j), dtype=bool)
    keep[in_lo_i - lo_i:in_hi_i - lo_i, in_lo_j - lo_j:in_hi_j - lo_j] = False
    keep[i - lo_i, :] = False
    keep[:, j - lo_j] = False

    vals = var_axis[lo_i:hi_i, lo_j:hi_j][keep]
    cnts = count_axis[lo_i:hi_i, lo_j:hi_j][keep]
    usable = np.isfinite(vals) & (cnts > 0)
    if not usable.any():
        return np.nan, 0.0

    v = vals[usable]
    c = cnts[usable].astype(np.float64)
    total = c.sum()
    denom = float((c * v).sum() / total)
    return denom, float(total)


def _cluster_summits(
    triu_hits: np.ndarray,
    pvals: np.ndarray,
    d1d: np.ndarray,
    gap: float,
) -> Tuple[List[List[int]], List[int]]:
    """Cluster accepted candidates by 1D proximity; pick lowest-score summits.

    Two candidates are linked when their distance is ``<= gap``; clusters
    are the connected components of that graph (single linkage).

    Parameters
    ----------
    triu_hits : bool array (n_hit,)
        Not used by the computation; kept for interface compatibility.
    pvals : float array (n_hit,)
        Per-candidate value to minimise (a p-value or any other
        "smaller is better" score).
    d1d : float array (n_hit, n_hit)
        Pairwise 1D genomic distance between candidates, computed upstream.
        Only the upper triangle is read; NaN distances never link.
    gap : float
        Max separation for two candidates to be in the same cluster.

    Returns
    -------
    clusters : list[list[int]]
        Sorted index lists per cluster (into ``pvals``), ordered by each
        cluster's smallest member.
    summits : list[int]
        Index into ``pvals`` of the lowest-valued entry per cluster; ties
        resolve to the smallest index.
    """
    pvals = np.asarray(pvals)
    n = pvals.size
    if n == 0:
        return [], []

    # Vectorised adjacency: test the upper triangle once, then symmetrise.
    with np.errstate(invalid="ignore"):
        upper = np.triu(np.asarray(d1d) <= gap, k=1)
    linked = upper | upper.T

    # Depth-first search over the adjacency matrix; nodes are marked when
    # pushed so each is expanded exactly once.
    visited = np.zeros(n, dtype=bool)
    clusters: List[List[int]] = []
    for seed in range(n):
        if visited[seed]:
            continue
        visited[seed] = True
        stack = [seed]
        members: List[int] = []
        while stack:
            node = stack.pop()
            members.append(node)
            fresh = np.flatnonzero(linked[node] & ~visited)
            visited[fresh] = True
            stack.extend(fresh.tolist())
        # Sorting makes argmin tie-breaking independent of traversal order.
        members.sort()
        clusters.append(members)

    summits = [comp[int(np.argmin(pvals[comp]))] for comp in clusters]
    return clusters, summits


# ----------------------------------------------------------------------
# Public API
# ----------------------------------------------------------------------
def call_loops_axiswise_f(
    cd,
    chrom: str,
    params: Optional[LoopCallerParams] = None,
    device: str = "auto",
    store: bool = True,
    result_key: str = "loops",
    verbose: bool = False,
) -> pd.DataFrame:
    """ArcFISH-style loop caller on a single chromosome.

    Parameters
    ----------
    cd : :class:`uchrom.ChromData`
        Tracing data. Must contain ``trace_id`` in ``spots`` and positive
        coordinates.
    chrom : str
        Chromosome name — matched against ``cd.spots['chrom']``.
    params : LoopCallerParams, optional
        Override any default.
    device : str
        ``'auto' | 'cpu' | 'cuda' | 'mps'`` — passed to GPU-friendly
        preprocessing.
    store : bool
        If True, also write the result DataFrame into
        ``cd.results[result_key]``.
    result_key : str
        Key under which to store the result (default ``'loops'``).
    verbose : bool
        If True, print progress messages prefixed with ``[loop/<chrom>]``.

    Returns
    -------
    DataFrame
        One row per called loop summit. An empty frame from
        ``_empty_result()`` is returned when there are no candidates, no
        candidate passes the FDR cutoff, or no summit survives the final
        filters.

    Raises
    ------
    ValueError
        If candidates exist but the median bin width derived from the bin
        ids is not a positive finite number.
    """
    if params is None:
        params = LoopCallerParams()
    from scipy.stats import f as _f_dist

    def _finish(df: pd.DataFrame) -> pd.DataFrame:
        # Single exit path: optionally persist on ``cd``, then return.
        if store:
            _store_result(cd, result_key, df)
        return df

    # Step 1 + 2: variances + LOWESS normalisation
    if verbose:
        print(f"[loop/{chrom}] computing axis variance cube...")
    cube = axis_variance_cube(cd, chrom=chrom, device=device)
    if verbose:
        print(f"[loop/{chrom}] {cube['n_traces']} traces, "
              f"{len(cube['bin_ids'])} bins; normalising...")
    cube = filter_normalize(cube, k_sigma=params.k_sigma, frac=params.frac)

    bin_ids = cube["bin_ids"]
    norm_var = cube["norm_var"]        # indexed [axis, i, j]
    count = cube["count"]              # indexed [axis, i, j]
    d1d = cube["genomic_distance"]     # indexed [i, j]

    # Step 3: axis weights
    w = axis_weight(cd, chrom=chrom, device=device)
    if verbose:
        print(f"[loop/{chrom}] axis weights x={w[0]:.3f}, y={w[1]:.3f}, z={w[2]:.3f}")

    # Candidate mask: upper triangle with 1D distance in [cut_lo, cut_up].
    # Checked before the ring width is derived, so an empty bin list exits
    # here instead of failing on the median of an empty sequence.
    cand = (d1d >= params.cut_lo) & (d1d <= params.cut_up)
    cand = np.triu(cand, k=1)
    cand_idx = np.argwhere(cand)
    if cand_idx.size == 0:
        if verbose:
            print(f"[loop/{chrom}] no candidates in [{params.cut_lo}, {params.cut_up}]")
        return _finish(_empty_result())

    # Figure out integer bin widths for the background ring
    med_bin_size = float(np.median([e - s for s, e in bin_ids]))
    if not np.isfinite(med_bin_size) or med_bin_size <= 0:
        raise ValueError(
            f"median bin size must be a positive finite number, got {med_bin_size}"
        )
    inner_bins = max(1, int(round(params.inner_cut / med_bin_size)))
    outer_bins = max(inner_bins + 1, int(round(params.outer_cut / med_bin_size)))
    if verbose:
        print(f"[loop/{chrom}] median bin size {med_bin_size:.0f} bp, "
              f"ring bins [{inner_bins}, {outer_bins}]")

    # Step 4: per-axis F-test
    if verbose:
        print(f"[loop/{chrom}] {len(cand_idx)} candidates; running F-test...")
    # Work in log-p space so small p-values stay distinguishable for the
    # summit tie-breaking score below. Axes that are skipped keep
    # log-p = 0, i.e. p = 1.
    per_axis_logp = np.full((3, len(cand_idx)), 0.0, dtype=np.float64)
    for k, (i, j) in enumerate(cand_idx):
        for ax in range(3):
            vij = norm_var[ax, i, j]
            cij = count[ax, i, j]
            if not np.isfinite(vij) or cij <= 0:
                continue
            denom, bg_count = _background_ring(
                norm_var[ax], count[ax], int(i), int(j),
                inner_bins, outer_bins,
            )
            if not np.isfinite(denom) or denom <= 0 or bg_count <= 3:
                continue
            F = vij / denom
            # Left tail: loops are expected to have below-background variance.
            per_axis_logp[ax, k] = _f_dist.logcdf(
                F, dfn=max(int(cij), 1), dfd=max(bg_count, 1)
            )

    # Convert log-p back to p for ACAT, clipping away from exactly 0 and 1.
    # NOTE(review): every non-finite log-p is reset to 0.0 here, which
    # includes -inf (e.g. from F == 0) and therefore turns such an axis
    # into p ~= 1 -- confirm that this is the intended handling.
    per_axis_logp = np.where(np.isfinite(per_axis_logp), per_axis_logp, 0.0)
    per_axis_p = np.clip(np.exp(per_axis_logp), 1e-300, 1 - 1e-15)

    # Step 4c: ACAT combine. The combined value is used directly as the raw
    # per-pair p-value (original author's note: cauchy_combination returns
    # the upper tail, 1 - CDF).
    combined = cauchy_combination(per_axis_p.T, weights=w, axis=1)
    pvals = np.clip(np.asarray(combined), 1e-300, 1.0)

    # Step 5: BH-FDR
    fdr = _bh_fdr(pvals)
    accept = (fdr < params.fdr_cutoff)
    if not accept.any():
        if verbose:
            print(f"[loop/{chrom}] no candidates pass FDR < {params.fdr_cutoff}")
        return _finish(_empty_result())

    # Step 6: cluster accepted hits by 1D distance between midpoints.
    # Summits are chosen by a full-precision log-p score (weighted sum of
    # per-axis log-p) so that p-value ties at the numerical floor are still
    # broken by the strength of the per-axis evidence.
    hits_idx = cand_idx[accept]
    hit_pvals = pvals[accept]
    w_arr = np.asarray(w).reshape(3, 1)
    hit_score = (per_axis_logp * w_arr).sum(axis=0)[accept]  # smaller = better

    mids_i = np.array([(bin_ids[i][0] + bin_ids[i][1]) * 0.5 for i in hits_idx[:, 0]])
    mids_j = np.array([(bin_ids[j][0] + bin_ids[j][1]) * 0.5 for j in hits_idx[:, 1]])
    di = np.abs(mids_i[:, None] - mids_i[None, :])
    dj = np.abs(mids_j[:, None] - mids_j[None, :])
    # Two hits are close only if both anchors are close (Chebyshev distance).
    d1d_hits = np.maximum(di, dj)

    # The first argument is documented as an (n_hit,) mask, so pass one of
    # that shape rather than the (n_candidates,) ``accept`` array.
    hit_mask = np.ones(hit_score.shape, dtype=bool)
    clusters, summits = _cluster_summits(hit_mask, hit_score, d1d_hits, params.gap)

    # Step 7: pseudo-contact frequency matrix, indexed [bin_i, bin_j].
    freq_mat = _pseudo_contact_frequency(cd, chrom, bin_ids, device=device)

    # Step 8: final filter — cluster size, p-value, pseudo-contact frequency
    rows = []
    accept_idx_in_cand = np.where(accept)[0]  # hit index -> candidate index
    for comp, summit in zip(clusters, summits):
        if len(comp) < params.min_cluster_size:
            continue
        i_b, j_b = hits_idx[summit]
        p_summit = hit_pvals[summit]
        if p_summit >= params.pval_cutoff:
            continue

        # Pseudo-contact frequency rule (ArcFISH "summit"):
        #   singleton cluster (== 1 triu entry) → freq > 1/2
        #   larger cluster                      → freq > 1/3
        is_singleton = len(comp) == 1
        freq_thresh = 0.5 if is_singleton else (1.0 / 3.0)
        f = freq_mat[int(i_b), int(j_b)]
        if not (np.isfinite(f) and f > freq_thresh):
            continue

        s1, e1 = bin_ids[i_b]
        s2, e2 = bin_ids[j_b]
        rows.append({
            "chrom1": chrom, "start1": int(s1), "end1": int(e1),
            "chrom2": chrom, "start2": int(s2), "end2": int(e2),
            "score": float(-np.log10(max(p_summit, 1e-300))),
            "pval": float(p_summit),
            "fdr": float(fdr[accept_idx_in_cand[summit]]),
            "cluster_size": int(len(comp)),
            "contact_freq": float(f),
            "summit_i": int(i_b),
            "summit_j": int(j_b),
        })

    df_out = pd.DataFrame(rows)
    if df_out.empty:
        df_out = _empty_result()
    return _finish(df_out)
def _store_result(cd, key: str, df: pd.DataFrame):
    """Assign ``df`` into ``cd.results[key]``, creating ``cd.results`` if needed."""
    results = getattr(cd, "results", None)
    if results is None:
        # Missing or None: start a fresh dict on the container.
        cd.results = {}
        results = cd.results
    results[key] = df


def _pseudo_contact_frequency(cd, chrom, bin_ids, device="auto"):
    """Fraction of traces whose 3D pairwise distance falls under a
    data-driven "contact" cutoff.

    Replicates ArcFISH's ``AxisWiseF.append_summit`` frequency rule: the
    cutoff is the NaN-mean pairwise distance at 1D separation equal to one
    bin width; ``freq_mat[i, j]`` is the fraction of traces where the
    (i, j) 3D distance is below that cutoff, conditional on both endpoints
    being observed (non-NaN distance). Pairs with no valid trace get 0.

    Returns a ``(B, B)`` float64 array with ``B = len(bin_ids)``. As a
    defensive fallback an all-ones matrix is returned (every pair then
    passes the caller's frequency thresholds) when the distance tensor does
    not match ``B``, when no bin pair is one bin width apart, or when the
    cutoff is not a positive finite number.
    """
    import torch

    diff, _, _ = _pairwise_diff_tensor(cd, chrom, device)
    # Original author's note: (3, T, B, B) → (T, B, B) 3D pdist.
    # Shapes come from the upstream helper -- TODO confirm.
    pdist = torch.sqrt((diff ** 2).sum(dim=0))
    pdist_np = pdist.detach().cpu().numpy().astype(np.float64)

    B = len(bin_ids)
    if pdist_np.shape[1] != B:
        return np.ones((B, B))  # defensive

    # 1D distance between bin midpoints
    mids = np.array([(s + e) * 0.5 for (s, e) in bin_ids])
    d1d = np.abs(mids[:, None] - mids[None, :])
    med_bin_size = float(np.median([e - s for s, e in bin_ids]))
    # ArcFISH uses |d1d[i] - d1d[j]| == 25e3 (one bin). Generalise to
    # a small tolerance around the median bin size.
    tol = max(1.0, 0.1 * med_bin_size)
    d1sel = np.abs(d1d - med_bin_size) <= tol
    if not d1sel.any():
        return np.ones((B, B))

    cutoff = float(np.nanmean(pdist_np[:, d1sel]))
    if not np.isfinite(cutoff) or cutoff <= 0:
        return np.ones((B, B))

    with np.errstate(invalid="ignore"):
        # NaN distances compare False, so they never count as contacts.
        n_contact = np.nansum(pdist_np < cutoff, axis=0)
        n_valid = np.sum(~np.isnan(pdist_np), axis=0)
        freq = np.where(n_valid > 0, n_contact / n_valid, 0.0)
    return freq.astype(np.float64)


def _empty_result() -> pd.DataFrame:
    """Return a zero-row DataFrame with the loop-call result schema.

    Column names and order match the rows built by
    :func:`call_loops_axiswise_f`, including ``contact_freq``, so that empty
    and non-empty results share one schema.
    """
    return pd.DataFrame({
        "chrom1": pd.Series(dtype=str),
        "start1": pd.Series(dtype="int64"),
        "end1": pd.Series(dtype="int64"),
        "chrom2": pd.Series(dtype=str),
        "start2": pd.Series(dtype="int64"),
        "end2": pd.Series(dtype="int64"),
        "score": pd.Series(dtype="float64"),
        "pval": pd.Series(dtype="float64"),
        "fdr": pd.Series(dtype="float64"),
        "cluster_size": pd.Series(dtype="int64"),
        "contact_freq": pd.Series(dtype="float64"),
        "summit_i": pd.Series(dtype="int64"),
        "summit_j": pd.Series(dtype="int64"),
    })


__all__ = [
    "LoopCallerParams",
    "call_loops_axiswise_f",
]