# Source code for uchrom.strc.tad.di
# Directionality Index (DI) based TAD detection
# Reference: Dixon et al. (2012) Nature
import torch
import numpy as np
[docs]
def get_device(device='auto'):
    """Resolve a device specifier to a ``torch.device``.

    Args:
        device: ``'auto'`` to pick the first available backend in the
            order CUDA, MPS, CPU; any other value is handed to
            ``torch.device`` unchanged.

    Returns:
        The resolved ``torch.device``.
    """
    if device == 'auto':
        # Probe accelerators in priority order; MPS is only queried when
        # CUDA is unavailable.
        if torch.cuda.is_available():
            device = 'cuda'
        elif torch.backends.mps.is_available():
            device = 'mps'
        else:
            device = 'cpu'
    return torch.device(device)
[docs]
def get_dtype(device):
    """Pick the floating-point dtype to use on ``device``.

    Args:
        device: Object exposing a ``type`` attribute (e.g. a
            ``torch.device``).

    Returns:
        ``torch.float32`` when ``device.type`` is ``'mps'`` (MPS only
        supports float32), otherwise ``torch.float64``.
    """
    return torch.float32 if device.type == 'mps' else torch.float64
[docs]
def calc_directionality_index(contact_mat, window=50, device='auto'):
    """Calculate the Directionality Index (DI) for each genomic bin.

    For bin ``i`` the mean contact of row ``i`` with the ``window`` bins
    immediately upstream (``A``) and downstream (``B``) is turned into a
    signed, chi-squared-like score::

        E  = (A + B) / 2
        DI = sign(B - A) * ((A - E)**2 / E + (B - E)**2 / E)

    Args:
        contact_mat: Contact matrix as a tensor or anything
            ``torch.as_tensor`` accepts (NumPy array, nested lists).
            Indexed as ``mat[i, j]``; only ``shape[0]`` is read for the
            bin count, so the matrix is presumably square -- confirm
            with callers.
        window: Number of bins examined on each side of a bin.
        device: Device specifier forwarded to ``get_device``.

    Returns:
        1-D tensor with one score per row of ``contact_mat``, on the
        resolved device. Bins closer than ``window`` to either end are
        left at 0 (no full window fits), as are bins where ``A == B`` or
        ``A + B <= 0``.

    Raises:
        ValueError: If ``window`` is negative.
    """
    if window < 0:
        raise ValueError(f"window must be non-negative, got {window}")

    dev = get_device(device)
    dtype = get_dtype(dev)
    # as_tensor passes tensors through untouched and wraps array-likes.
    mat = torch.as_tensor(contact_mat).to(dtype=dtype, device=dev)
    n = mat.shape[0]
    di = torch.zeros(n, dtype=dtype, device=dev)
    if window == 0:
        # Empty windows carry no directional signal.
        return di

    # Every bin visited here has a full window on both sides, so no
    # clamping against the matrix edges is needed.
    for i in range(window, n - window):
        avg_a = mat[i, i - window:i].sum() / window
        avg_b = mat[i, i + 1:i + window + 1].sum() / window
        # chi-squared-like formula; skipped when it would be 0 or 0/0
        if avg_a != avg_b and (avg_a + avg_b) > 0:
            e = (avg_a + avg_b) / 2
            sign = 1 if (avg_b - avg_a) > 0 else -1
            di[i] = sign * ((avg_a - e)**2 / e + (avg_b - e)**2 / e)
    return di
[docs]
def smooth_with_moving_average(signal, window_size):
    """Smooth a signal with a centered moving average.

    The signal is reflect-padded so the output has the same length as the
    input, then convolved with a uniform kernel of ``window_size`` taps.

    Args:
        signal: Tensor to smooth; treated as a 1-D floating-point
            sequence (assumed from the padding/convolution calls --
            confirm with callers).
        window_size: Number of samples averaged per output value. Values
            of 1 or less disable smoothing.

    Returns:
        A copy of ``signal`` when ``window_size <= 1``; otherwise the
        smoothed tensor, truncated to ``len(signal)`` entries.
    """
    if window_size <= 1:
        return signal.clone()

    length = len(signal)
    # For even windows the extra tap goes on the left side.
    left = window_size // 2
    right = window_size - left - 1

    # conv1d/pad expect (batch, channel, length).
    batched = signal.unsqueeze(0).unsqueeze(0)
    padded = torch.nn.functional.pad(batched, (left, right), mode='reflect')

    weights = torch.ones(
        window_size, device=signal.device, dtype=signal.dtype
    ) / window_size
    averaged = torch.nn.functional.conv1d(
        padded, weights.unsqueeze(0).unsqueeze(0)
    )
    return averaged.squeeze()[:length]
[docs]
def detect_tad_boundaries(di, min_size_frac=0.05):
    """Split bins into domains at negative-to-positive DI sign changes.

    Args:
        di: Tensor of per-bin DI scores; it is moved to the CPU and
            converted to NumPy for scanning.
        min_size_frac: Minimum domain length as a fraction of the number
            of bins; a sign change closer than this to the current
            domain start is ignored.

    Returns:
        List of ``(start, end)`` bin-index tuples. Consecutive domains
        share their boundary bin, and the last domain always ends at the
        final bin. Empty input yields an empty list.
    """
    total = len(di)
    shortest = int(total * min_size_frac)
    scores = di.cpu().numpy()

    found = []
    if total == 0:
        return found

    left = 0
    previous = 0.0
    # The final bin never opens a new domain; it only closes the last one.
    for pos, current in enumerate(scores[:-1]):
        crossed = previous < 0 and current > 0
        if crossed and pos - left >= shortest:
            found.append((left, pos))
            left = pos
        previous = current

    found.append((left, total - 1))
    return found
[docs]
def get_domains(contact_mat, smoothing_param=0.1, min_size_frac=0.05,
                window=50, device='auto'):
    """Identify TADs: DI calculation -> smoothing -> boundary detection.

    Args:
        contact_mat: Contact matrix forwarded to
            ``calc_directionality_index``.
        smoothing_param: Moving-average window expressed as a fraction of
            the number of bins; the window is never smaller than 1.
        min_size_frac: Minimum domain size as a fraction of the number of
            bins, forwarded to ``detect_tad_boundaries``.
        window: DI window size in bins.
        device: Device specifier forwarded to the DI calculation.

    Returns:
        The domain list produced by ``detect_tad_boundaries`` on the
        smoothed DI track.
    """
    raw_di = calc_directionality_index(
        contact_mat, window=window, device=device
    )
    # Scale the smoothing window with the track length, at least one bin.
    span = max(1, int(len(raw_di) * smoothing_param))
    return detect_tad_boundaries(
        smooth_with_moving_average(raw_di, span), min_size_frac
    )
[docs]
def calc_di_batch(contact_mat, window=50, device='auto'):
    """Batch-optimized DI calculation using tensor operations.

    Computes, for every bin ``i`` with a full window on both sides, the
    mean contact of row ``i`` with the ``window`` bins upstream (``A``)
    and downstream (``B``), then::

        E  = (A + B) / 2
        DI = sign(B - A) * ((A - E)**2 / E + (B - E)**2 / E)

    All bins are processed at once by gathering both windows into
    ``(bins, window)`` tensors, so temporary memory grows with
    ``bins * window``. Results agree with a per-bin loop up to
    floating-point summation order.

    Args:
        contact_mat: Contact matrix as a tensor or anything
            ``torch.as_tensor`` accepts (NumPy array, nested lists).
            Assumed square: columns up to ``shape[0] - 1`` are indexed --
            confirm with callers.
        window: Number of bins examined on each side of a bin.
        device: Device specifier forwarded to ``get_device``.

    Returns:
        1-D tensor with one score per row of ``contact_mat``, on the
        resolved device. Bins closer than ``window`` to either end are
        left at 0, as are bins where ``A == B`` or ``A + B <= 0``.

    Raises:
        ValueError: If ``window`` is negative.
    """
    if window < 0:
        raise ValueError(f"window must be non-negative, got {window}")

    dev = get_device(device)
    dtype = get_dtype(dev)
    # as_tensor passes tensors through untouched and wraps array-likes.
    mat = torch.as_tensor(contact_mat).to(dtype=dtype, device=dev)
    n = mat.shape[0]
    di = torch.zeros(n, dtype=dtype, device=dev)
    if window == 0 or n <= 2 * window:
        # No bin has a full window on both sides.
        return di

    centers = torch.arange(window, n - window, device=dev)
    rows = centers.unsqueeze(1)
    steps = torch.arange(1, window + 1, device=dev)
    # Broadcasting (bins, 1) against (window,) gathers each bin's window.
    avg_a = mat[rows, rows - steps].sum(dim=1) / window
    avg_b = mat[rows, rows + steps].sum(dim=1) / window

    # Only score bins where the statistic is defined and non-zero; this
    # also keeps the division by E away from zero.
    keep = ((avg_a + avg_b) > 0) & (avg_a != avg_b)
    a = avg_a[keep]
    b = avg_b[keep]
    e = (a + b) / 2
    di[centers[keep]] = torch.sign(b - a) * ((a - e)**2 / e + (b - e)**2 / e)
    return di