"""Agent-readable discovery schema for :class:`uchrom.core.ChromData`.
The schema is intentionally stored inside ``cd.uns`` as a JSON payload so
it round-trips through ``.h5cd`` without requiring a new HDF5 layout.
"""
from __future__ import annotations
import hashlib
import json
from datetime import datetime, timezone
from typing import Any, Mapping
import numpy as np
import pandas as pd
# Key under which the packed discovery schema is stored in ``cd.uns``; it is
# also excluded from the ``uns`` key listing inside the schema itself.
DISCOVERY_SCHEMA_KEY = "auto_discovery_schema"
# Value written to the schema's ``schema_version`` field and used as the
# fallback ``version`` when packing a schema that lacks one.
DISCOVERY_SCHEMA_VERSION = "0.1"
# Format tag written to the ``format`` field of the packed ``uns`` entry.
DISCOVERY_SCHEMA_FORMAT = "uchrom.auto_discovery.schema+json"
[docs]
def build_discovery_schema(
    cdata,
    *,
    dataset_name: str | None = None,
    include_linked_adata: bool = True,
    max_catalog_items: int = 500,
) -> dict[str, Any]:
    """Assemble the agent-readable discovery schema for a ``ChromData`` object.

    The result is intended to be JSON-serializable so that
    :func:`pack_schema` can persist it under
    ``cd.uns['auto_discovery_schema']``.

    Args:
        cdata: Object whose ``coords``, ``spots``, ``tracks``, ``cells``,
            ``traces``, ``cellm``, ``layers``, ``results``, ``uns``,
            ``linked_adata`` and ``n_spots``/``n_traces``/``n_cells``
            attributes are read. ``chroms`` is optional and defaults to empty.
        dataset_name: Display name of the dataset. When falsy, the name falls
            back to ``uns['source']``, then ``uns['dataset']``, then
            ``"ChromData"``.
        include_linked_adata: Forwarded to ``_linked_anndata_field``; when
            false the linked AnnData object is not inspected.
        max_catalog_items: Maximum number of names listed in the chromosome,
            track and linked obs/var name catalogs.

    Returns:
        The schema dict, with ``schema_hash`` added last from ``_hash_json``.
    """
    uns = cdata.uns

    # Per-container structural description (shapes, columns, dtypes, keys).
    stored_uns_keys = (str(key) for key in uns.keys())
    field_info = {
        "coords": {
            "axis": "spot",
            "shape": _shape(cdata.coords),
            "dtype": str(getattr(cdata.coords, "dtype", "")),
            "description": "Spot-level 3D coordinates.",
        },
        "spots": _dataframe_field(cdata.spots, axis="spot"),
        "tracks": _dataframe_field(cdata.tracks, axis="spot"),
        "cells": _dataframe_field(cdata.cells, axis="cell"),
        "traces": _dataframe_field(cdata.traces, axis="trace"),
        "cellm": _array_mapping_field(cdata.cellm, axis="cell"),
        "layers": _array_mapping_field(cdata.layers, axis="spot"),
        "results": {"keys": sorted(str(key) for key in cdata.results.keys())},
        # The schema's own storage key is left out of the listing.
        "uns": {
            "keys": sorted(key for key in stored_uns_keys if key != DISCOVERY_SCHEMA_KEY)
        },
    }

    chrom_names = _as_str_list(getattr(cdata, "chroms", []))
    type_counts = _cell_type_counts(cdata)
    linked_info = _linked_anndata_field(
        cdata,
        include_linked_adata=include_linked_adata,
        max_catalog_items=max_catalog_items,
    )
    track_names = _columns(cdata.tracks)
    # Placeholder gene catalog used when no linked AnnData was described.
    empty_catalog = {"n": 0, "values": [], "truncated": False, "sha1": ""}
    gene_catalog = linked_info.get("var_names", empty_catalog)

    if dataset_name:
        resolved_name = dataset_name
    else:
        resolved_name = str(uns.get("source") or uns.get("dataset") or "ChromData")

    created = datetime.now(timezone.utc).replace(microsecond=0).isoformat()

    dataset_info = {
        "name": resolved_name,
        "genome_assembly": _json_scalar(uns.get("genome_assembly")),
        "xyz_unit": _json_scalar(uns.get("xyz_unit")),
    }

    summary = {
        "n_spots": int(cdata.n_spots),
        "n_traces": int(cdata.n_traces),
        "n_cells": int(cdata.n_cells),
        "n_chroms": len(chrom_names),
        "n_tracks": len(track_names),
    }

    axes = {
        "spot": "Rows of coords/spots/tracks; one observed genomic bin in one trace.",
        "trace": "A chromatin fiber or allele-specific chromosome trace.",
        "cell": "Cell-level metadata, embeddings, and linked RNA observations.",
        "gene": "Variables in linked_adata when RNA expression is available.",
        "marker": "Columns in tracks when per-spot IF/RNA marker scores are available.",
    }

    modalities = {
        "chromatin_tracing": {
            "present": True,
            "fields": ["coords", "spots", "traces"],
            "operations": [
                "chromosome_subset",
                "cell_subset",
                "trace_subset",
                "pairwise_3d_distance",
                "intra_chromatin_distance",
                "inter_chromatin_distance",
            ],
        },
        "if_tracks": {
            "present": bool(track_names),
            "fields": ["tracks"],
            "operations": [
                "marker_high_low_bin_selection",
                "marker_stratified_distance",
                "per_cell_marker_summary",
                "per_cell_type_marker_summary",
            ],
        },
        "cell_metadata": {
            "present": len(cdata.cells) > 0,
            "fields": ["cells", "cellm"],
            "operations": ["cell_type_stratification", "embedding_visualization"],
        },
        "rna_expression": {
            "present": bool(linked_info.get("present")),
            "fields": ["linked_adata"],
            "operations": [
                "gene_expression_lookup",
                "expression_stratification",
                "gene_marker_correlation",
                "chromatin_expression_association",
            ],
        },
    }

    catalogs = {
        "chroms": _catalog(chrom_names, max_catalog_items),
        "tracks": _catalog(track_names, max_catalog_items),
        "cell_types": {
            "n": len(type_counts),
            "counts": type_counts,
            "values": list(type_counts.keys()),
        },
        "genes": gene_catalog,
    }

    constraints = {
        "cell_axis_alignment": (
            "cells.index, cellm arrays, and linked_adata.obs_names should share "
            "the same cell_id order when linked_adata is present."
        ),
        "spot_axis_alignment": "coords, spots, and tracks are row-aligned on the spot axis.",
        "free_code_policy": (
            "Discovery code agents may write new Python code, but exploratory "
            "runs should be recorded in notebooks with data checks and verification."
        ),
    }

    schema = {
        "schema_version": DISCOVERY_SCHEMA_VERSION,
        "schema_type": "uchrom_multiomics_auto_discovery",
        "created_utc": created,
        "dataset": dataset_info,
        "summary": summary,
        "axes": axes,
        "modalities": modalities,
        "fields": field_info,
        "linked_adata": linked_info,
        "catalogs": catalogs,
        "constraints": constraints,
        "known_missing": _known_missing(cdata, linked_info),
        "recommended_verification": [
            "required_fields_exist",
            "minimum_cell_count",
            "minimum_spot_or_trace_count",
            "finite_numeric_output",
            "statistical_hypothesis_test",
            "runtime_under_budget",
            "deterministic_rerun",
            "negative_control_or_permutation",
            "redundancy_against_existing_parameters",
        ],
    }
    # The hash is computed over the finished schema and stored alongside it.
    schema["schema_hash"] = _hash_json(schema)
    return schema
[docs]
def pack_schema(schema: Mapping[str, Any]) -> dict[str, str]:
    """Serialize *schema* into the string-only entry stored in ``uns``.

    Args:
        schema: Discovery schema mapping to persist.

    Returns:
        A dict with ``format`` (the module's format tag), ``version`` (the
        schema's ``schema_version`` or the module default, as a string) and
        ``payload`` (compact JSON with sorted keys).
    """
    encoded = json.dumps(
        schema,
        sort_keys=True,
        separators=(",", ":"),
        default=_json_default,
    )
    version = schema.get("schema_version", DISCOVERY_SCHEMA_VERSION)
    packed = {"format": DISCOVERY_SCHEMA_FORMAT, "version": str(version)}
    packed["payload"] = encoded
    return packed
[docs]
def unpack_schema(raw: Any) -> dict[str, Any]:
    """Decode a schema read from ``cd.uns['auto_discovery_schema']``.

    Accepted inputs:

    * ``None`` -> empty dict.
    * A mapping with a non-``None`` ``payload`` -> the JSON-decoded payload
      (``bytes`` payloads are decoded as UTF-8 first).
    * A mapping without a payload -> the mapping itself, normalized through a
      JSON round trip.
    * ``bytes`` / ``str`` -> decoded as JSON text.
    * Anything else -> normalized through a JSON round trip.
    """
    if raw is None:
        return {}

    if isinstance(raw, Mapping):
        packed = raw.get("payload")
        if packed is None:
            # Not a packed entry: treat the mapping as the schema itself.
            return _json_roundtrip(raw)
        text = packed.decode("utf-8") if isinstance(packed, bytes) else packed
        return json.loads(str(text))

    text = raw.decode("utf-8") if isinstance(raw, bytes) else raw
    return json.loads(text) if isinstance(text, str) else _json_roundtrip(text)
[docs]
def validate_discovery_schema(schema: Mapping[str, Any], cdata=None) -> list[str]:
    """Return a list of human-readable validation issues for *schema*.

    Checks that the required top-level keys exist and that ``schema_type``
    has the expected value. When *cdata* is given and the schema has a
    ``summary``, the stored ``n_spots``/``n_traces``/``n_cells`` are compared
    with the corresponding attributes of *cdata*.

    Malformed summaries are reported as issues instead of raising: a summary
    that is not a mapping, or a count that cannot be converted to ``int``
    (for example ``None`` or a non-numeric string).

    Args:
        schema: Discovery schema mapping to check.
        cdata: Optional object exposing ``n_spots``, ``n_traces`` and
            ``n_cells``.

    Returns:
        A list of issue strings; empty when no problem was found.
    """
    issues: list[str] = []

    required = ("schema_version", "schema_type", "summary", "modalities", "fields", "catalogs")
    issues.extend(f"missing top-level key: {key}" for key in required if key not in schema)

    if schema.get("schema_type") != "uchrom_multiomics_auto_discovery":
        issues.append("schema_type is not uchrom_multiomics_auto_discovery")

    if cdata is None or "summary" not in schema:
        return issues

    summary = schema["summary"]
    if not isinstance(summary, Mapping):
        issues.append(f"summary is not a mapping: {type(summary).__name__}")
        return issues

    expected = {
        "n_spots": int(cdata.n_spots),
        "n_traces": int(cdata.n_traces),
        "n_cells": int(cdata.n_cells),
    }
    for key, value in expected.items():
        recorded = summary.get(key, -1)
        try:
            matches = int(recorded) == value
        except (TypeError, ValueError, OverflowError):
            # A value that is not integer-like can never match the data.
            matches = False
        if not matches:
            issues.append(f"summary.{key}={summary.get(key)!r} != {value}")
    return issues
[docs]
def schema_to_agent_context(schema: Mapping[str, Any], *, max_items: int = 40) -> str:
    """Render *schema* as a compact, prompt-ready text summary.

    The text lists the dataset header, one line per modality (with at most
    six operations each), the chromosome / cell-type / track catalogs, the
    linked AnnData status, any ``known_missing`` entries and the recommended
    verification steps.

    Args:
        schema: Discovery schema mapping; absent sections are treated as empty.
        max_items: Maximum number of entries shown per catalog.

    Returns:
        The summary as a single newline-joined string.
    """
    dataset = schema.get("dataset", {})
    counts = schema.get("summary", {})
    catalogs = schema.get("catalogs", {})
    linked = schema.get("linked_adata", {})

    shape_line = (
        "shape: "
        f"{counts.get('n_spots', 0)} spots, "
        f"{counts.get('n_traces', 0)} traces, "
        f"{counts.get('n_cells', 0)} cells"
    )
    out = [
        "# ChromData discovery schema",
        "",
        f"dataset: {dataset.get('name', 'ChromData')}",
        f"genome: {dataset.get('genome_assembly') or 'unknown'}",
        f"xyz_unit: {dataset.get('xyz_unit') or 'unknown'}",
        shape_line,
        "",
        "modalities:",
    ]

    for modality, details in schema.get("modalities", {}).items():
        if details.get("present"):
            state = "present"
        else:
            state = "missing"
        operations = ", ".join(details.get("operations", [])[:6])
        out.append(f"- {modality}: {state}; operations: {operations}")

    out.append("")
    out.append(f"chroms: {_format_catalog(catalogs.get('chroms', {}), max_items)}")
    out.append(f"cell_types: {_format_cell_types(catalogs.get('cell_types', {}), max_items)}")
    out.append(f"tracks: {_format_catalog(catalogs.get('tracks', {}), max_items)}")

    if not linked.get("present"):
        out.append("linked_adata: missing")
    else:
        out.append(f"linked_adata: shape={linked.get('shape')}, X={linked.get('x_type')}")
        out.append(f"genes: {_format_catalog(catalogs.get('genes', {}), max_items)}")

    gaps = schema.get("known_missing", [])
    if gaps:
        out += ["", "known_missing:"]
        out += [f"- {gap}" for gap in gaps]

    out += ["", "verification_required:"]
    out += [f"- {step}" for step in schema.get("recommended_verification", [])]
    return "\n".join(out)
def _dataframe_field(df: pd.DataFrame | None, *, axis: str) -> dict[str, Any]:
if df is None:
return {"present": False, "axis": axis, "shape": [0, 0], "columns": []}
return {
"present": True,
"axis": axis,
"shape": [int(df.shape[0]), int(df.shape[1])],
"columns": _columns(df),
"dtypes": {str(k): str(v) for k, v in df.dtypes.items()},
"index_name": None if df.index.name is None else str(df.index.name),
}
def _array_mapping_field(mapping: Mapping[str, Any], *, axis: str) -> dict[str, Any]:
    """Describe a mapping of named arrays for the ``fields`` section.

    Records whether the mapping is non-empty, its sorted key names and, per
    key, the value's shape and ``dtype`` (empty string when the value has no
    ``dtype`` attribute).
    """
    shapes: dict[str, list[int]] = {}
    dtypes: dict[str, str] = {}
    for key, array in mapping.items():
        label = str(key)
        shapes[label] = _shape(array)
        dtypes[label] = str(getattr(array, "dtype", ""))

    return {
        "present": bool(mapping),
        "axis": axis,
        "keys": sorted(str(key) for key in mapping.keys()),
        "shapes": shapes,
        "dtypes": dtypes,
    }
def _linked_anndata_field(cdata, *, include_linked_adata: bool, max_catalog_items: int) -> dict[str, Any]:
    """Describe the AnnData object linked to *cdata*, if any.

    The base entry is built from the ``uns['linked_anndata']`` metadata
    (``path``, ``n_obs``, ``n_vars``, ``cell_id_axis``) with ``present`` set
    to false. When *include_linked_adata* is true and ``cdata.linked_adata``
    is not ``None``, the entry is marked present and extended with shape,
    ``X`` type name, obs/var columns, layer/obsm/uns keys and obs/var name
    catalogs limited to *max_catalog_items* names.
    """
    stored = dict(cdata.uns.get("linked_anndata", {}) or {})
    info = {
        "present": False,
        "path": _json_scalar(stored.get("path")),
        "n_obs": _maybe_int(stored.get("n_obs")),
        "n_vars": _maybe_int(stored.get("n_vars")),
        "cell_id_axis": _json_scalar(stored.get("cell_id_axis")),
    }

    # The linked object is only touched when the caller asked for it.
    adata = cdata.linked_adata if include_linked_adata else None
    if adata is None:
        return info

    details = {
        "present": True,
        "shape": [int(adata.n_obs), int(adata.n_vars)],
        "x_type": type(adata.X).__name__,
        "obs_columns": _columns(adata.obs),
        "var_columns": _columns(adata.var),
        "layers": sorted(str(key) for key in adata.layers.keys()),
        "obsm": sorted(str(key) for key in adata.obsm.keys()),
        "uns_keys": sorted(str(key) for key in adata.uns.keys()),
        "obs_names": _catalog(_as_str_list(adata.obs_names), max_catalog_items),
        "var_names": _catalog(_as_str_list(adata.var_names), max_catalog_items),
    }
    for name, value in details.items():
        info[name] = value
    return info
def _cell_type_counts(cdata) -> dict[str, int]:
    """Count cells per ``cell_type`` label.

    Args:
        cdata: Object exposing a ``cells`` DataFrame, or ``None`` when there
            is no cell table.

    Returns:
        Mapping of label (as ``str``) to count, in ``value_counts`` order.
        Empty when ``cells`` is ``None``, has no rows, or has no
        ``cell_type`` column.
    """
    cells = cdata.cells
    # A missing cell table is treated like an empty one, matching the other
    # helpers that accept a ``None`` DataFrame.
    if cells is None or len(cells) == 0 or "cell_type" not in cells.columns:
        return {}
    counts = cells["cell_type"].astype(str).value_counts(dropna=False)
    return {str(label): int(count) for label, count in counts.items()}
def _known_missing(cdata, linked: Mapping[str, Any]) -> list[str]:
    """List descriptions of optional components that *cdata* lacks.

    Looks for ``cellm['if_mean']`` and the ``raw_rna_spots``,
    ``scrna_reference`` and ``gene_annotation`` entries of ``uns``. The
    ``raw_rna_spots`` gap is only reported when *linked* is marked present.
    """
    uns = cdata.uns
    checks = (
        (
            "if_mean" not in cdata.cellm,
            "cellm['if_mean'] per-cell IF mean matrix",
        ),
        (
            bool(linked.get("present")) and not uns.get("raw_rna_spots"),
            "raw RNA seqFISH spot geometry as a first-class ChromData component",
        ),
        (
            not uns.get("scrna_reference"),
            "scRNA reference matrix for external expression comparison",
        ),
        (
            not uns.get("gene_annotation"),
            "gene annotation cache for gene-neighborhood analyses",
        ),
    )
    return [message for is_missing, message in checks if is_missing]
def _catalog(values: list[str], max_items: int) -> dict[str, Any]:
    """Summarize a list of names as a catalog entry.

    Returns the total count, the first *max_items* names, whether the list
    was cut short, and a SHA-1 digest computed over the complete list.
    """
    names = list(values)
    total = len(names)
    entry: dict[str, Any] = {"n": total}
    entry["values"] = names[:max_items]
    entry["truncated"] = total > max_items
    # The digest covers every name, not just the listed prefix.
    entry["sha1"] = _hash_names(names)
    return entry
def _format_catalog(catalog: Mapping[str, Any], max_items: int) -> str:
    """Format a catalog entry as ``"<n> [v1, v2, ... ]"``.

    At most *max_items* values are shown. `` ...`` is appended when the
    catalog is flagged as truncated or its ``n`` exceeds the number shown.
    """
    shown = list(catalog.get("values", []))[:max_items]
    total = catalog.get("n", len(shown))
    body = ", ".join(str(value) for value in shown)
    if catalog.get("truncated") or catalog.get("n", 0) > len(shown):
        body += " ..."
    return f"{total} [{body}]"
def _format_cell_types(catalog: Mapping[str, Any], max_items: int) -> str:
    """Format cell-type counts as ``"<n types> [name=count, ...]"``.

    Only the first *max_items* entries of ``catalog['counts']`` are shown;
    `` ...`` is appended when entries were left out.
    """
    counts = catalog.get("counts") or {}
    pairs = list(counts.items())[:max_items]
    body = ", ".join(f"{name}={count}" for name, count in pairs)
    if len(counts) > len(pairs):
        body += " ..."
    return f"{len(counts)} [" + body + "]"
def _columns(df: pd.DataFrame | None) -> list[str]:
if df is None:
return []
return [str(c) for c in df.columns]
def _shape(value: Any) -> list[int]:
    """Return ``value.shape`` as a list of ints, or ``[]`` if it has none."""
    dims = getattr(value, "shape", [])
    return list(map(int, dims))
def _as_str_list(values: Any) -> list[str]:
    """Materialize *values* and convert every element to ``str``."""
    materialized = list(values)
    return list(map(str, materialized))
def _maybe_int(value: Any) -> int | None:
if value is None:
return None
try:
return int(value)
except Exception:
return None
def _json_scalar(value: Any) -> Any:
    """Coerce *value* into a JSON-friendly scalar.

    NumPy scalars are first unwrapped to their Python equivalent and then
    handled like any other value. This way ``np.bytes_`` ends up as ``str``,
    and unwrapped values with no JSON form (complex numbers, datetimes) are
    stringified instead of leaking through.

    Args:
        value: Arbitrary metadata value.

    Returns:
        ``value`` unchanged when it is ``None``, ``str``, ``int``, ``float``
        or ``bool``; UTF-8 decoded text for ``bytes``; ``str(value)`` for
        anything else.

    Raises:
        UnicodeDecodeError: If a ``bytes`` value is not valid UTF-8.
    """
    if isinstance(value, np.generic):
        # Unwrap, then fall through so the result gets the same treatment as
        # a plain Python object (``np.bytes_.item()`` returns ``bytes``).
        value = value.item()
    if isinstance(value, bytes):
        return value.decode("utf-8")
    if value is None or isinstance(value, (str, int, float, bool)):
        return value
    return str(value)
def _json_default(value: Any) -> Any:
    """``json.dumps`` fallback for objects the encoder cannot handle itself.

    NumPy scalars are unwrapped with ``.item()``; NumPy arrays and pandas
    ``Index``/``Series`` become lists; ``bytes`` are decoded as UTF-8;
    everything else is stringified.
    """
    if isinstance(value, np.generic):
        return value.item()
    if isinstance(value, (np.ndarray, pd.Index, pd.Series)):
        return value.tolist()
    return value.decode("utf-8") if isinstance(value, bytes) else str(value)
def _json_roundtrip(value: Any) -> dict[str, Any]:
    """Normalize *value* to plain JSON types by encoding and decoding it.

    Objects the encoder cannot handle are converted by ``_json_default``.
    """
    encoded = json.dumps(value, default=_json_default)
    return json.loads(encoded)
def _hash_names(values: list[str]) -> str:
    """Return the SHA-1 hex digest of *values*.

    Each name is UTF-8 encoded and followed by a NUL byte, so the digest
    depends on both the names and their order.
    """
    framed = b"".join(name.encode("utf-8") + b"\0" for name in values)
    return hashlib.sha1(framed).hexdigest()
def _hash_json(value: Mapping[str, Any]) -> str:
    """Return a SHA-1 hex digest of *value*'s content.

    ``schema_hash`` and ``created_utc`` are left out before hashing, and the
    remainder is serialized as compact JSON with sorted keys. *value* itself
    is not modified.
    """
    excluded = ("schema_hash", "created_utc")
    content = {key: item for key, item in value.items() if key not in excluded}
    encoded = json.dumps(
        content,
        sort_keys=True,
        separators=(",", ":"),
        default=_json_default,
    )
    return hashlib.sha1(encoded.encode("utf-8")).hexdigest()