"""FISHnet — per-allele chromatin-domain caller via modularity maximisation.
Independent reimplementation of Patel et al. 2025
(*FISHnet: detecting chromatin domains in single-cell sequential
Oligopaints imaging data*, Nature Methods 22:1255–1264,
doi:10.1038/s41592-025-02688-1). Not derived from the upstream
repository (``github.com/RohanpatelUpenn/FISHnet``), which ships without
an open-source licence.
Overview
--------
FISHnet is a **per-trace** domain caller — for each allele's pairwise
distance matrix it sweeps a range of distance thresholds, binarises at
each threshold, smooths, and runs Louvain modularity maximisation on the
resulting weighted graph. Thresholds at which the community count is
stable across ``plateau_size`` or more adjacent steps are grouped, and
the consensus partition within each plateau is taken as the final
domain call for that allele.
This complements the other callers in :mod:`uchrom.strc.tad`:
- :func:`get_domains` — classical Dixon directionality index on a
population contact matrix.
- :func:`call_tads_by_pval` — axis-wise F-test on per-bin variance.
- :func:`call_domains_fishnet_trace` — FISHnet on a single allele.
- :func:`call_domains_fishnet` — FISHnet applied across all alleles of a
chromosome, with an ensemble "domain-mask" aggregation.
References
----------
Patel, R., Pham, K., Chandrashekar, H., Phillips-Cremins, J. E. (2025).
FISHnet: detecting chromatin domains in single-cell sequential
Oligopaints imaging data. *Nature Methods* 22, 1255–1264.
"""
from __future__ import annotations
from dataclasses import dataclass, field
from typing import List, Optional, Tuple
import numpy as np
import pandas as pd
[docs]
@dataclass
class FISHnetParams:
    """Hyperparameters consumed by :func:`call_domains_fishnet_trace`.

    Defaults follow Patel et al. 2025.
    """

    # --- threshold sweep ---------------------------------------------------
    # Explicit sweep, in the same distance units as the input matrix.
    # When ``None`` the sweep is generated from the fields below.
    thresholds: Optional[np.ndarray] = None
    # Spacing of a generated sweep; ``None`` means an evenly spaced
    # ``linspace`` with ``max_thresholds`` points.
    threshold_step: Optional[float] = None
    # Sweep bounds; ``None`` falls back to the smallest positive finite
    # distance / the largest finite distance of the matrix.
    threshold_min: Optional[float] = None
    threshold_max: Optional[float] = None
    # Safety cap on the number of generated thresholds.
    max_thresholds: int = 200

    # --- plateau detection and post-processing -----------------------------
    # Minimum number of consecutive thresholds sharing a community count.
    plateau_size: int = 4
    # Boxcar half-width ``w``: the smoothing block is (2*w)×(2*w).
    window_size: int = 2
    # Label runs shorter than this many bins are merged into a neighbour.
    size_exclusion: int = 3
    # Boundaries within this many bins of each other are merged.
    merge_tol: int = 3

    # --- Louvain -------------------------------------------------------------
    # Number of seeded Louvain repetitions per threshold.
    n_louvain_runs: int = 20
    # Resolution parameter handed to the Louvain routine.
    resolution: float = 1.0

    # --- missing data ---------------------------------------------------------
    # Minimum fraction (0–1) of bins with at least one finite distance.
    min_coverage: float = 0.6
    # Linearly interpolate bins that have no finite distance at all.
    impute: bool = True
# ----------------------------------------------------------------------
# Low-level helpers
# ----------------------------------------------------------------------
def _linear_impute(dist: np.ndarray) -> np.ndarray:
    """Fill the rows/columns of bins that have no finite distance at all.

    A bin counts as observed when its row holds at least one finite
    value. Every unobserved bin receives, as both its row and its
    column, a linear blend of the rows of the nearest observed bin on
    each side (weighted by index distance); at the ends of the matrix,
    where only one side has an observed bin, that bin's row is copied.
    The diagonal is reset to zero afterwards.

    Parameters
    ----------
    dist : ndarray (n_bins, n_bins)
        Pairwise distance matrix; it is not modified.

    Returns
    -------
    ndarray
        A filled copy of ``dist``. When every bin is observed, or none
        is, the copy is returned as-is (diagonal untouched).
    """
    filled = dist.copy()
    has_data = np.isfinite(filled).any(axis=1)
    # Nothing to fill, or nothing to anchor the interpolation on.
    if has_data.all() or not has_data.any():
        return filled

    observed = np.flatnonzero(has_data)
    for gap in np.flatnonzero(~has_data):
        # ``observed`` is sorted and never contains ``gap``, so ``pos``
        # splits it into the bins left and right of the gap.
        pos = np.searchsorted(observed, gap)
        lo = observed[pos - 1] if pos > 0 else None
        hi = observed[pos] if pos < observed.size else None
        if lo is not None and hi is not None:
            frac = (gap - lo) / (hi - lo)
            filled[gap] = (1 - frac) * filled[lo] + frac * filled[hi]
            filled[:, gap] = filled[gap]
        else:
            # Only one side has data: copy the nearest observed row.
            anchor = lo if lo is not None else hi
            filled[gap] = filled[anchor]
            filled[:, gap] = filled[anchor]

    np.fill_diagonal(filled, 0.0)
    return filled
def _boxcar_nanmean(mat: np.ndarray, window: int) -> np.ndarray:
    """Smooth ``mat`` with a NaN-aware (2*w)×(2*w) boxcar mean.

    Cell ``(i, j)`` becomes the NaN-ignoring mean of the block
    ``mat[i-w:i+w, j-w:j+w]``, with the block clipped at the matrix
    edges so the output keeps the input's shape. Cells that end up NaN
    are reported as ``0.0``.

    Parameters
    ----------
    mat : ndarray
        Matrix to smooth.
    window : int
        Half-width ``w``; a non-positive value disables smoothing and a
        float64 copy of ``mat`` is returned.

    Returns
    -------
    ndarray of float64, same shape as ``mat``.
    """
    if window <= 0:
        return mat.astype(np.float64, copy=True)

    size = mat.shape[0]
    # Half-open span [k - w, k + w), clipped to the matrix, for every index.
    spans = [
        slice(max(0, k - window), min(size, k + window)) for k in range(size)
    ]
    smoothed = np.zeros_like(mat, dtype=np.float64)
    with np.errstate(invalid="ignore"):
        for r, row_span in enumerate(spans):
            for c, col_span in enumerate(spans):
                patch = mat[row_span, col_span]
                smoothed[r, c] = np.nanmean(patch) if patch.size else np.nan
    return np.nan_to_num(smoothed, nan=0.0)
def _louvain_labels(
    A: np.ndarray,
    resolution: float,
    seed: Optional[int],
) -> np.ndarray:
    """Partition the nodes of ``A`` with a single Louvain run.

    The strictly upper triangle of ``A`` is read as the edge weights of
    an undirected graph; entries that are not strictly positive yield no
    edge. The graph is handed to
    :func:`networkx.algorithms.community.louvain_communities`.

    Parameters
    ----------
    A : ndarray (n, n)
        Weighted adjacency matrix.
    resolution : float
        Louvain resolution parameter.
    seed : int or None
        Seed forwarded to the Louvain routine.

    Returns
    -------
    ndarray of int64, shape (n,)
        Community label per node. A graph without any edge gives every
        node its own label (``0 .. n-1``).
    """
    import networkx as nx
    from networkx.algorithms.community import louvain_communities

    n_nodes = A.shape[0]
    rows, cols = np.triu_indices(n_nodes, k=1)
    weights = A[rows, cols]
    keep = weights > 0

    graph = nx.Graph()
    graph.add_nodes_from(range(n_nodes))
    graph.add_weighted_edges_from(
        zip(rows[keep].tolist(), cols[keep].tolist(), weights[keep].tolist())
    )
    if graph.number_of_edges() == 0:
        return np.arange(n_nodes, dtype=np.int64)

    found = louvain_communities(
        graph, weight="weight", resolution=resolution, seed=seed
    )
    labels = np.full(n_nodes, -1, dtype=np.int64)
    for label, members in enumerate(found):
        labels[list(members)] = label

    # Any node left without a community gets a fresh singleton label.
    first_free = int(labels.max()) + 1 if labels.max() >= 0 else 0
    unassigned = np.flatnonzero(labels < 0)
    labels[unassigned] = first_free + np.arange(unassigned.size)
    return labels
def _count_switches(labels: np.ndarray) -> int:
    """Count the contiguous runs in a 1-D label sequence.

    The result is one more than the number of positions where a label
    differs from its predecessor; an empty sequence has zero runs.
    """
    if len(labels) == 0:
        return 0
    flips = np.count_nonzero(labels[1:] != labels[:-1])
    return int(flips) + 1
def _consensus_partition(partitions: np.ndarray) -> np.ndarray:
    """Select the most representative partition of a set.

    Every pair of partitions is scored with the adjusted Rand index and
    the partition with the highest mean score against all partitions
    (itself included) is returned; ties go to the earliest row.

    Parameters
    ----------
    partitions : ndarray (n_runs, n_bins)
        One label vector per row.

    Returns
    -------
    ndarray (n_bins,)
        The selected row of ``partitions``.
    """
    from sklearn.metrics import adjusted_rand_score

    n_runs = partitions.shape[0]
    if n_runs == 1:
        return partitions[0]

    similarity = np.zeros((n_runs, n_runs), dtype=np.float64)
    # The score is symmetric, so only the upper triangle is computed.
    for a, b in zip(*np.triu_indices(n_runs)):
        score = adjusted_rand_score(partitions[a], partitions[b])
        similarity[a, b] = score
        similarity[b, a] = score

    best = int(np.argmax(similarity.mean(axis=0)))
    return partitions[best]
def _find_plateaus(
    thresholds: np.ndarray, counts: np.ndarray, min_length: int,
) -> List[np.ndarray]:
    """Collect the threshold ranges over which the community count is flat.

    Parameters
    ----------
    thresholds : ndarray
        Sweep values, aligned with ``counts``.
    counts : ndarray
        Community count per threshold; NaN marks a threshold without a
        usable count.
    min_length : int
        Minimum number of consecutive equal counts forming a plateau.
        A non-positive value turns the whole sweep into one plateau.

    Returns
    -------
    list of ndarray
        One slice of ``thresholds`` per plateau, in sweep order. NaN
        counts never form a plateau.
    """
    if min_length <= 0 or len(counts) == 0:
        return [thresholds] if len(thresholds) else []

    total = len(counts)
    found: List[np.ndarray] = []
    start = 0
    while start < total:
        # Grow the half-open run [start, stop) while the count repeats.
        stop = start + 1
        while stop < total and counts[stop] == counts[start]:
            stop += 1
        if stop - start >= min_length and not np.isnan(counts[start]):
            found.append(thresholds[start:stop])
        start = stop
    return found
def _size_exclude(labels: np.ndarray, min_size: int) -> np.ndarray:
    """Absorb label runs shorter than ``min_size`` into a neighbouring run.

    The sequence is scanned left to right; the first run that is too
    short is overwritten with the label of its longer neighbouring run
    (the left one on a tie, the only one available at either end), and
    the scan restarts. This repeats until no short run remains or the
    whole sequence is a single run.

    Parameters
    ----------
    labels : ndarray (n_bins,)
        Label per bin; it is not modified.
    min_size : int
        Minimum run length to keep. Values ``<= 1`` disable the filter.

    Returns
    -------
    ndarray of int64
        Filtered copy of ``labels``.
    """
    merged = labels.astype(np.int64).copy()
    total = merged.size
    if min_size <= 1:
        return merged

    def _extent(pos: int, step: int, value) -> int:
        # Number of consecutive bins equal to ``value`` from ``pos`` onward,
        # walking in direction ``step``.
        length = 0
        while 0 <= pos < total and merged[pos] == value:
            length += 1
            pos += step
        return length

    while True:
        # Find the first run [start, stop) that is too short.
        start = 0
        short_run = None
        while start < total:
            stop = start + _extent(start, 1, merged[start])
            if stop - start < min_size:
                short_run = (start, stop)
                break
            start = stop
        if short_run is None:
            return merged

        start, stop = short_run
        has_left = start > 0
        has_right = stop < total
        if not has_left and not has_right:
            # A single run spans everything: nothing to merge into.
            return merged
        if not has_left:
            fill = merged[stop]
        elif not has_right:
            fill = merged[start - 1]
        else:
            left_label = merged[start - 1]
            right_label = merged[stop]
            left_len = _extent(start - 1, -1, left_label)
            right_len = _extent(stop, 1, right_label)
            fill = left_label if left_len >= right_len else right_label
        merged[start:stop] = fill
def _boundaries_from_labels(labels: np.ndarray) -> List[int]:
    """List the positions at which a new label run begins.

    Each returned index ``b`` satisfies ``labels[b] != labels[b - 1]``,
    i.e. it is the first bin of a run and the exclusive end of the run
    to its left. Index 0 is never reported.
    """
    if len(labels) == 0:
        return []
    change_points = np.nonzero(labels[1:] != labels[:-1])[0]
    return (change_points + 1).tolist()
def _merge_close_boundaries(
    all_boundaries: List[List[int]], tol: int, n_bins: int,
) -> List[List[int]]:
    """Snap nearby boundaries from different plateaus onto shared positions.

    All boundaries are pooled, sorted and chained into groups: a value
    joins the current group when it lies within ``tol`` bins of the
    group's last member. Every member of a group is then replaced by the
    group's rounded mean, so plateaus that disagree by a few bins end up
    with identical boundary coordinates.

    Parameters
    ----------
    all_boundaries : list of list of int
        One boundary list per plateau.
    tol : int
        Maximum gap, in bins, between consecutive pooled boundaries of
        the same group.
    n_bins : int
        Number of bins; only boundaries strictly inside ``(0, n_bins)``
        are meaningful.

    Returns
    -------
    list of list of int
        One sorted, de-duplicated list per input list. Boundaries outside
        ``(0, n_bins)`` are always dropped, including when no valid
        boundary exists at all.
    """
    # Filter once, up front, so every return path applies the same
    # range rule to its output.
    in_range = [
        sorted({b for b in lst if 0 < b < n_bins}) for lst in all_boundaries
    ]
    pooled = sorted({b for lst in in_range for b in lst})
    if not pooled:
        return in_range

    groups: List[List[int]] = [[pooled[0]]]
    for value in pooled[1:]:
        if value - groups[-1][-1] <= tol:
            groups[-1].append(value)
        else:
            groups.append([value])

    # Raw boundary -> canonical position of its group.
    canonical = {}
    for group in groups:
        centre = int(round(float(np.mean(group))))
        canonical.update(dict.fromkeys(group, centre))

    return [sorted({canonical[b] for b in lst}) for lst in in_range]
def _domains_from_boundaries(
    boundaries: List[int], n_bins: int,
) -> List[Tuple[int, int]]:
    """Turn domain-start indices into half-open domain intervals.

    Parameters
    ----------
    boundaries : list of int
        Bin indices at which a new domain starts; duplicates and order
        do not matter.
    n_bins : int
        Total number of bins.

    Returns
    -------
    list of (start, end_exclusive)
        Consecutive intervals between ``0``, the sorted boundaries and
        ``n_bins``; intervals spanning less than one bin are dropped.
    """
    edges = [0, *sorted(set(boundaries)), n_bins]
    return [(lo, hi) for lo, hi in zip(edges, edges[1:]) if hi - lo >= 1]
# ----------------------------------------------------------------------
# Per-trace FISHnet
# ----------------------------------------------------------------------
[docs]
def call_domains_fishnet_trace(
    dist: np.ndarray,
    params: Optional[FISHnetParams] = None,
    seed_base: int = 0,
) -> dict:
    """Run FISHnet on a single-allele pairwise distance matrix.

    Parameters
    ----------
    dist : ndarray (n_bins, n_bins)
        Symmetric pairwise distance matrix, NaN for missing spots, in nm.
    params : FISHnetParams, optional
        Hyperparameters. Use defaults if ``None``.
    seed_base : int
        Base seed for the Louvain runs (reproducibility); run ``r`` at
        threshold index ``t`` is seeded with ``seed_base + 1000*t + r``.

    Returns
    -------
    dict with keys
        ``domains`` : list of (start, end_exclusive) bin indices, the
            segments of ``[0, n_bins)`` delimited by ``boundaries``.
        ``boundaries`` : sorted union, over all plateaus, of the
            canonical (tolerance-merged) boundary bin indices.
        ``domain_mask`` : (n_bins, n_bins) int — number of plateaus whose
            size-filtered consensus partition gives bins i and j the
            same label (0 if never).
        ``n_plateaus`` : number of plateaus detected
        ``coverage`` : fraction of bins with ≥1 non-NaN distance

        When nothing can be called (empty matrix, coverage below
        ``params.min_coverage``, no usable threshold range, or no
        plateau) the lists are empty, the mask is all zeros and
        ``n_plateaus`` is 0.

    Raises
    ------
    ValueError
        If ``dist`` is not a square 2-D matrix.
    """
    if params is None:
        params = FISHnetParams()
    dist = np.asarray(dist, dtype=np.float64)
    if dist.ndim != 2 or dist.shape[0] != dist.shape[1]:
        raise ValueError(
            f"dist must be a square 2-D matrix, got shape {dist.shape}"
        )
    n_bins = dist.shape[0]

    def _no_call(cov: float) -> dict:
        # Result returned whenever no domain can be called.
        return {
            "domains": [],
            "boundaries": [],
            "domain_mask": np.zeros((n_bins, n_bins), dtype=np.int64),
            "n_plateaus": 0,
            "coverage": cov,
        }

    # An empty matrix has no bins to measure coverage on.
    if n_bins == 0:
        return _no_call(0.0)

    # Coverage check (a bin counts as present if any pairwise distance to
    # another bin is finite — i.e. the bin's 3D coordinate was detected)
    coverage = float(np.isfinite(dist).any(axis=1).mean())
    if coverage < params.min_coverage:
        return _no_call(coverage)
    if params.impute and coverage < 1.0:
        dist = _linear_impute(dist)

    thresholds = _fishnet_threshold_sweep(dist, params)
    n_runs = params.n_louvain_runs
    if thresholds is None or len(thresholds) == 0 or n_runs < 1:
        return _no_call(coverage)

    # Per-threshold: run Louvain ``n_runs`` times and cache both the
    # partitions and the community-count statistic — the plateau phase
    # re-uses the same partitions.
    partitions = np.full(
        (len(thresholds), n_runs, n_bins), -1, dtype=np.int64,
    )
    counts = np.full(len(thresholds), np.nan, dtype=np.float64)
    missing = ~np.isfinite(dist)
    for ti, thr in enumerate(thresholds):
        contact = (dist < thr).astype(np.float64)
        contact[missing] = 0.0
        adjacency = _boxcar_nanmean(contact, params.window_size)
        run_counts = []
        for r in range(n_runs):
            labels = _louvain_labels(
                adjacency, params.resolution, seed=seed_base + ti * 1000 + r,
            )
            partitions[ti, r] = labels
            k = int(len(np.unique(labels)))
            # Reject spurious 2-community partitions with many interleaved
            # flips — per the FISHnet paper, genuine domain partitions
            # should produce contiguous runs in the 1-D bin order.
            if k == 2 and _count_switches(labels) > 2:
                continue
            run_counts.append(k)
        if run_counts:
            counts[ti] = round(float(np.mean(run_counts)))

    # Plateaus are located over threshold *indices* so the cached
    # partitions can be addressed directly; looking thresholds up by
    # float value breaks on NaN or duplicated thresholds.
    plateaus = _find_plateaus(
        np.arange(len(thresholds)), counts, params.plateau_size,
    )
    if not plateaus:
        return _no_call(coverage)

    # Per-plateau consensus: first across the runs of each threshold,
    # then across the thresholds of the plateau.
    all_boundaries: List[List[int]] = []
    all_labels: List[np.ndarray] = []
    for plateau_idx in plateaus:
        per_threshold = np.stack(
            [_consensus_partition(partitions[ti]) for ti in plateau_idx]
        )
        consensus = _consensus_partition(per_threshold)
        consensus = _size_exclude(consensus, params.size_exclusion)
        all_labels.append(consensus)
        all_boundaries.append(_boundaries_from_labels(consensus))
    all_boundaries = _merge_close_boundaries(
        all_boundaries, params.merge_tol, n_bins,
    )

    # Ensemble domain mask: for each plateau, bins i,j → 1 if same label
    mask = np.zeros((n_bins, n_bins), dtype=np.int64)
    for labels in all_labels:
        mask += (labels[:, None] == labels[None, :]).astype(np.int64)

    boundaries_union = sorted({b for bl in all_boundaries for b in bl})
    return {
        "domains": _domains_from_boundaries(boundaries_union, n_bins),
        "boundaries": boundaries_union,
        "domain_mask": mask,
        "n_plateaus": len(plateaus),
        "coverage": coverage,
    }


def _fishnet_threshold_sweep(dist: np.ndarray, params) -> Optional[np.ndarray]:
    """Build the distance-threshold sweep, or ``None`` without a valid range."""
    if params.thresholds is not None:
        # An explicit sweep is used verbatim; the ``max_thresholds`` cap
        # only guards generated sweeps, which are the only ones that
        # have a (vmin, vmax) range to re-sample from.
        return np.asarray(params.thresholds, dtype=np.float64)

    vmin = params.threshold_min
    vmax = params.threshold_max
    if vmin is None:
        positive = dist[np.isfinite(dist) & (dist > 0)]
        vmin = float(np.min(positive)) if positive.size else 0.0
    if vmax is None:
        vmax = float(np.nanmax(dist))
    if not np.isfinite(vmin) or not np.isfinite(vmax) or vmax <= vmin:
        return None

    if params.threshold_step is not None:
        sweep = np.arange(vmin, vmax + 1e-9, params.threshold_step)
        if len(sweep) <= params.max_thresholds:
            return sweep
    # No step given, or the stepped sweep exceeds the safety cap.
    return np.linspace(vmin, vmax, params.max_thresholds)
# ----------------------------------------------------------------------
# ChromData-level wrapper
# ----------------------------------------------------------------------
[docs]
def call_domains_fishnet(
    cd,
    chrom: str,
    params: Optional[FISHnetParams] = None,
    store: bool = True,
    result_key: str = "fishnet_domains",
    ensemble_key: str = "fishnet_ensemble",
    verbose: bool = False,
    max_traces: Optional[int] = None,
) -> pd.DataFrame:
    """Call FISHnet domains on the traces of one chromosome.

    Each trace's pairwise distance matrix is passed to
    :func:`call_domains_fishnet_trace` (seeded with the trace's
    position), its domains are collected into a table and its domain
    mask is summed into an ensemble mask.

    Parameters
    ----------
    cd : ChromData
        Source of the traces, read through ``cd.to_dataframe()``.
    chrom : str
        Chromosome to process.
    params : FISHnetParams, optional
        Hyperparameters; defaults are used when omitted.
    store : bool
        If True, write the table to ``cd.results[result_key]`` and the
        ensemble record (``chrom``, ``mask``, ``bin_ids``, ``n_traces``)
        to ``cd.results[ensemble_key]``; ``cd.results`` is created when
        it is missing or ``None``.
    result_key, ensemble_key : str
        Keys used in ``cd.results`` when ``store`` is True.
    verbose : bool
        Print a header, a progress line and a summary.
    max_traces : int, optional
        Hard cap on the number of traces processed (useful for quick
        previews); the first ``max_traces`` traces are used.

    Returns
    -------
    DataFrame with one row per (trace, domain):
        ``chrom, trace_id, domain_idx, start, end, start_bin, end_bin``.
        ``start``/``end`` come from the first field of the domain's first
        bin id and the second field of its last bin id; ``end_bin`` is
        exclusive. Traces without any called domain contribute no rows.
    """
    from uchrom.fea.distance import _bin_coord_cube, _pairwise_distance_per_trace

    params = params or FISHnetParams()
    cube, bin_ids, trace_ids = _bin_coord_cube(cd.to_dataframe(), chrom=chrom)
    dists = _pairwise_distance_per_trace(cube)  # (n_traces, n_bins, n_bins)
    n_traces, n_bins, _ = dists.shape
    if max_traces is not None:
        n_traces = min(n_traces, max_traces)

    if verbose:
        print(f"[fishnet/{chrom}] running on {n_traces} traces × "
              f"{n_bins} bins; plateau_size={params.plateau_size}, "
              f"window={params.window_size}")

    column_order = [
        "chrom", "trace_id", "domain_idx", "start", "end",
        "start_bin", "end_bin",
    ]
    records = []
    ensemble_mask = np.zeros((n_bins, n_bins), dtype=np.int64)
    n_called = 0
    for trace_no in range(n_traces):
        result = call_domains_fishnet_trace(
            dists[trace_no], params=params, seed_base=trace_no,
        )
        ensemble_mask += result["domain_mask"]
        trace_domains = result["domains"]
        if not trace_domains:
            # Nothing called for this trace; it only counts in the mask.
            continue
        n_called += 1
        records.extend(
            {
                "chrom": chrom,
                "trace_id": trace_ids[trace_no],
                "domain_idx": domain_no,
                "start": int(bin_ids[first_bin][0]),
                "end": int(bin_ids[end_bin - 1][1]),
                "start_bin": int(first_bin),
                "end_bin": int(end_bin),
            }
            for domain_no, (first_bin, end_bin) in enumerate(trace_domains)
        )
        if verbose and (trace_no + 1) % 25 == 0:
            print(f"[fishnet/{chrom}] {trace_no + 1}/{n_traces} traces")

    out = pd.DataFrame(records, columns=column_order)
    if verbose:
        print(f"[fishnet/{chrom}] called domains in "
              f"{n_called}/{n_traces} traces; "
              f"{len(out)} domains total")

    if store:
        if getattr(cd, "results", None) is None:
            cd.results = {}
        cd.results[result_key] = out
        cd.results[ensemble_key] = {
            "chrom": chrom,
            "mask": ensemble_mask,
            "bin_ids": bin_ids,
            "n_traces": n_traces,
        }
    return out
# Public API of this module: the parameter container plus the per-trace
# and per-chromosome entry points. Underscore-prefixed helpers are not
# exported by ``from ... import *``.
__all__ = [
"FISHnetParams",
"call_domains_fishnet_trace",
"call_domains_fishnet",
]