# Source code for uchrom.strc.tad.arc_pval

"""ArcFISH-style TAD caller based on intra/inter-domain variance F-test.

Independent implementation of Yu et al. 2025 (*ArcFISH*,
bioRxiv 2025.11.26.690837v1) ``TADCaller.by_pval``.  Not derived from
the GPL-3.0 upstream source.

Algorithm (per chromosome)
--------------------------
1. Build the per-axis normalised variance ``norm_var[c, i, j]`` via
   :func:`uchrom.fea.arc.filter_normalize` (same preprocessing as the
   loop caller).
2. For each candidate boundary position ``b`` with a window of size
   ``window_bp`` (bp), split the window into a left segment
   ``L = [b - w/2, b)`` and a right segment ``R = (b, b + w/2]``.  The
   intra-domain set is the pairs within ``L`` or within ``R``; the
   inter-domain set is the pairs that cross ``b``.
3. Per-axis F-test on the count-weighted intra-/inter-mean variance:
   ``F_c = intra_mean_c / inter_mean_c``.  If ``b`` is a true TAD
   boundary, intra-domain distances are tighter than inter-domain →
   ``F`` is small → left-tail F-CDF gives small p.
4. ACAT Cauchy combine across axes with the ArcFISH axis weights.
5. Run ``scipy.signal.find_peaks`` on the combined statistic to pick
   local minima in p-value (peaks in ``-log10 p``).
6. BH-FDR adjust the per-peak p-values.  Boundaries are peaks with
   ``fdr < fdr_cutoff``.
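7. (Optional) **Hierarchical TADs** — level 1 uses every accepted
   boundary (the finest partition); each further level drops the
   next-weakest boundary and re-emits the regions that the remaining
   boundaries define, producing progressively coarser nested levels
   (capped by ``max_levels``).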

Output
------
A DataFrame with columns
``chrom, start, end, level, score, pval, fdr``.  Also written to
``cd.results[result_key]`` (``'tads'`` by default) when ``store=True``.
"""

from __future__ import annotations

from dataclasses import dataclass
from typing import List, Optional, Tuple

import numpy as np
import pandas as pd

from uchrom.fea.arc import axis_variance_cube, filter_normalize, axis_weight
from uchrom.utils.stats import cauchy_combination


@dataclass
class TADCallerParams:
    """Settings consumed by :func:`call_tads_by_pval`.

    The defaults are meant to mirror ArcFISH's ``TADCaller(method='pval')``.
    Field order is part of the interface (positional construction).
    """

    # --- boundary scan -------------------------------------------------
    window_bp: float = 1e5  # full width (bp) of the window around each candidate boundary
    fdr_cutoff: float = 0.1  # a peak is accepted as a boundary when its BH q-value is below this

    # --- hierarchy -----------------------------------------------------
    hierarchical: bool = True  # emit several nested levels instead of one partition
    max_levels: int = 4  # upper bound on the number of hierarchy levels

    # --- peak picking (forwarded to scipy.signal.find_peaks) ------------
    prominence: float = 0.0  # only forwarded when > 0
    distance: int = 1  # only forwarded when >= 1

    # --- preprocessing (forwarded to filter_normalize) -----------------
    k_sigma: float = 4.0
    frac: float = 0.1
# ---------------------------------------------------------------------- # Helpers # ---------------------------------------------------------------------- def _bh_fdr(pvals: np.ndarray) -> np.ndarray: p = np.asarray(pvals, dtype=np.float64) out = np.full_like(p, np.nan) finite = np.isfinite(p) if not finite.any(): return out pf = p[finite] n = pf.size order = np.argsort(pf) ranked = pf[order] q = ranked * n / (np.arange(n) + 1.0) q = np.minimum.accumulate(q[::-1])[::-1] q = np.clip(q, 0.0, 1.0) o = np.empty_like(pf) o[order] = q out[finite] = o return out def _boundary_pvals( norm_var: np.ndarray, count: np.ndarray, bin_midpoints: np.ndarray, axis_weights: np.ndarray, window_bp: float, ): """Per-bin ACAT p-value for intra-vs-inter domain variance ratio. Returns ``(pvals (B,), per_axis_logp (3, B))``. Bins whose window does not have at least one intra and one inter pair get NaN. """ from scipy.stats import f as _f_dist B = len(bin_midpoints) pvals = np.full(B, np.nan, dtype=np.float64) per_axis_logp = np.full((3, B), 0.0, dtype=np.float64) half_w = 0.5 * window_bp # Count-weighted var and weights per axis — precompute once w_var = np.where(np.isfinite(norm_var), norm_var * count, 0.0) w_cnt = np.where(np.isfinite(norm_var), count, 0.0) for b in range(B): left_mask = (bin_midpoints >= bin_midpoints[b] - half_w) & \ (bin_midpoints < bin_midpoints[b]) right_mask = (bin_midpoints > bin_midpoints[b]) & \ (bin_midpoints <= bin_midpoints[b] + half_w) if left_mask.sum() < 1 or right_mask.sum() < 1: continue li = np.where(left_mask)[0] ri = np.where(right_mask)[0] # Pairs (i, j) where i < j and (both in L) or (both in R) → intra # Pairs (i, j) where i in L and j in R → inter intra_mask = np.zeros((B, B), dtype=bool) inter_mask = np.zeros((B, B), dtype=bool) for i in li: for j in li: if i < j: intra_mask[i, j] = True for i in ri: for j in ri: if i < j: intra_mask[i, j] = True for i in li: for j in ri: if i < j: inter_mask[i, j] = True combined_logp = np.zeros(3) combined_F = 
np.zeros(3) ok = True for ax in range(3): wv_intra = w_var[ax][intra_mask].sum() wc_intra = w_cnt[ax][intra_mask].sum() wv_inter = w_var[ax][inter_mask].sum() wc_inter = w_cnt[ax][inter_mask].sum() if wc_intra <= 0 or wc_inter <= 0: ok = False break intra_mean = wv_intra / wc_intra inter_mean = wv_inter / wc_inter if intra_mean <= 0 or inter_mean <= 0: ok = False break F = intra_mean / inter_mean combined_F[ax] = F combined_logp[ax] = _f_dist.logcdf( F, dfn=max(wc_intra, 1.0), dfd=max(wc_inter, 1.0) ) if not ok: continue per_axis_logp[:, b] = combined_logp # ACAT combine on p-scale p = np.clip(np.exp(combined_logp), 1e-300, 1 - 1e-15) pvals[b] = float(cauchy_combination(p[None, :], weights=axis_weights, axis=1)[0]) return pvals, per_axis_logp def _emit_tads( boundaries: List[int], bin_ids, chrom: str, level: int, pvals: np.ndarray, fdr: np.ndarray, ) -> List[dict]: """Given sorted boundary bin indices, emit (start, end) TADs.""" if not boundaries: return [] B = len(bin_ids) pts = [0] + sorted(boundaries) + [B] rows = [] for a, b in zip(pts[:-1], pts[1:]): if b - a < 2: continue s = bin_ids[a][0] e = bin_ids[b - 1][1] # Score: min fdr of the enclosing boundaries (worst confidence) adj = [] if a > 0: adj.append(a) if b < B: adj.append(b) if not adj: score, p, q = np.nan, np.nan, np.nan else: qs = [fdr[k] for k in adj if np.isfinite(fdr[k])] ps = [pvals[k] for k in adj if np.isfinite(pvals[k])] q = float(min(qs)) if qs else np.nan p = float(min(ps)) if ps else np.nan score = -np.log10(max(p, 1e-300)) if np.isfinite(p) else np.nan rows.append({ "chrom": chrom, "start": int(s), "end": int(e), "level": int(level), "score": score, "pval": p, "fdr": q, }) return rows # ---------------------------------------------------------------------- # Public API # ----------------------------------------------------------------------
def call_tads_by_pval(
    cd,
    chrom: str,
    params: Optional[TADCallerParams] = None,
    device: str = "auto",
    store: bool = True,
    result_key: str = "tads",
    verbose: bool = False,
) -> pd.DataFrame:
    """ArcFISH-style TAD caller on a single chromosome.

    Parameters
    ----------
    cd
        Data object handed to ``axis_variance_cube`` and ``axis_weight``;
        the result table is written to ``cd.results`` when ``store=True``.
    chrom
        Chromosome to process.
    params
        Caller settings; ``TADCallerParams()`` is used when ``None``.
    device
        Forwarded to ``axis_variance_cube`` and ``axis_weight``.
    store, result_key
        When ``store`` is true the returned table (including an empty one)
        is stored in ``cd.results[result_key]``.
    verbose
        Print progress messages.

    Returns
    -------
    pandas.DataFrame
        Columns ``chrom, start, end, level, score, pval, fdr``.

        Level 1 is the *finest* partition, built from all accepted
        boundaries.  Level ``k`` drops the ``k - 1`` weakest boundaries, so
        higher levels are coarser.  Each level lists every segment defined
        by its kept boundaries, so TADs not touched by a dropped boundary
        repeat across levels.  With ``params.hierarchical=False`` only
        level 1 is produced.  The table is empty (but typed) when no peak
        passes ``params.fdr_cutoff``.
    """
    from scipy.signal import find_peaks

    params = params or TADCallerParams()

    def _log(msg: str) -> None:
        if verbose:
            print(f"[tad/{chrom}] {msg}")

    def _finish(df: pd.DataFrame) -> pd.DataFrame:
        # Single exit path: optionally persist, then return.
        if store:
            _store(cd, result_key, df)
        return df

    _log("variance cube...")
    cube = axis_variance_cube(cd, chrom=chrom, device=device)
    _log("filter + normalise...")
    cube = filter_normalize(cube, k_sigma=params.k_sigma, frac=params.frac)
    norm_var = cube["norm_var"]
    count = cube["count"]
    bin_ids = cube["bin_ids"]
    B = len(bin_ids)
    midpoints = np.array([(s + e) * 0.5 for (s, e) in bin_ids])

    w = axis_weight(cd, chrom=chrom, device=device)
    if verbose:
        print(f"[tad/{chrom}] axis weights: {w.round(3)}")

    pvals, _ = _boundary_pvals(norm_var, count, midpoints, w, params.window_bp)

    # Peak-picking on -log10 p (NaN / non-positive p contribute 0).
    neg_log_p = np.where(
        np.isfinite(pvals) & (pvals > 0),
        -np.log10(np.clip(pvals, 1e-300, 1.0)),
        0.0,
    )
    peak_idx, _ = find_peaks(
        neg_log_p,
        prominence=params.prominence if params.prominence > 0 else None,
        distance=params.distance if params.distance >= 1 else None,
    )
    if len(peak_idx) == 0:
        _log("no peaks found")
        return _finish(_empty_tads())

    peak_fdr = _bh_fdr(pvals[peak_idx])
    accept = np.flatnonzero(np.isfinite(peak_fdr) & (peak_fdr < params.fdr_cutoff))
    if accept.size == 0:
        _log(f"no boundary passes fdr<{params.fdr_cutoff}")
        return _finish(_empty_tads())

    # Full per-bin fdr array (NaN for non-peaks).
    fdr_full = np.full(B, np.nan, dtype=np.float64)
    fdr_full[peak_idx] = peak_fdr
    accepted_bins = list(peak_idx[accept])

    # Strongest first.  BH q-values are often tied by the step-up
    # monotonisation, so break ties with the raw p-value (q is
    # non-decreasing in p, so this never contradicts the q ordering).
    sorted_by_strength = sorted(
        accepted_bins, key=lambda i: (float(fdr_full[i]), float(pvals[i]))
    )
    n_bdry = len(sorted_by_strength)

    rows: List[dict] = []
    if not params.hierarchical:
        rows.extend(
            _emit_tads(sorted_by_strength, bin_ids, chrom, 1, pvals, fdr_full)
        )
    else:
        levels = min(params.max_levels, max(1, n_bdry))
        for level in range(1, levels + 1):
            # Level k keeps all but the (k - 1) weakest boundaries.
            n_keep = max(1, n_bdry - (level - 1))
            rows.extend(
                _emit_tads(
                    sorted_by_strength[:n_keep], bin_ids, chrom, level,
                    pvals, fdr_full,
                )
            )

    df_out = pd.DataFrame(rows)
    if df_out.empty:
        df_out = _empty_tads()
    return _finish(df_out)
def _empty_tads() -> pd.DataFrame: return pd.DataFrame({ "chrom": pd.Series(dtype=str), "start": pd.Series(dtype="int64"), "end": pd.Series(dtype="int64"), "level": pd.Series(dtype="int64"), "score": pd.Series(dtype="float64"), "pval": pd.Series(dtype="float64"), "fdr": pd.Series(dtype="float64"), }) def _store(cd, key: str, df: pd.DataFrame): if getattr(cd, "results", None) is None: cd.results = {} cd.results[key] = df __all__ = [ "TADCallerParams", "call_tads_by_pval", ]