import json
import os
from pathlib import Path
import sys
import warnings
import numpy as np
from dipy.core.gradients import unique_bvals_tolerance
from dipy.direction.peaks import PeaksAndMetrics
from dipy.reconst._force_search import search_inner_product as _cython_search
from dipy.reconst.base import ReconstFit, ReconstModel
from dipy.reconst.multi_voxel import multi_voxel_fit
from dipy.reconst.shm import sf_to_sh
from dipy.sims.force import (
generate_force_simulations,
get_default_diffusivity_config,
load_force_simulations,
save_force_simulations,
)
from dipy.utils.logging import logger
# Small positive constant added to denominators (normalisations, ranges,
# bandwidths) so that degenerate inputs never cause a division by zero.
EPSILON = 1.0e-12
def _get_force_cache_dir():
    """Locate, creating it if necessary, the FORCE simulation cache directory.

    The base directory is read from the ``DIPY_HOME`` environment variable
    when that variable is set and is ``~/.dipy`` otherwise; cached
    simulations live in its ``force_simulations`` sub-directory.

    Returns
    -------
    cache_dir : Path
        Existing directory in which cached simulations are stored.
    """
    env_home = os.environ.get("DIPY_HOME")
    base = (Path("~").expanduser() / ".dipy") if env_home is None else Path(env_home)
    target = base / "force_simulations"
    target.mkdir(parents=True, exist_ok=True)
    return target
def _gtab_matches(entry_bvals, entry_bvecs, gtab, *, bval_tol=10.0, bvec_tol=1e-3):
    """Tell whether registry-stored b-values/b-vectors agree with *gtab*.

    Parameters
    ----------
    entry_bvals : list
        b-values read from the cache registry.
    entry_bvecs : list of list
        b-vectors read from the cache registry.
    gtab : GradientTable
        Gradient table of the current model (``bvals`` and ``bvecs`` are used).
    bval_tol : float, optional
        Absolute tolerance applied to every b-value.
    bvec_tol : float, optional
        Absolute tolerance applied to every b-vector coordinate.

    Returns
    -------
    match : bool
        False as soon as shapes differ; otherwise True only when both the
        b-values and the b-vectors are element-wise close.
    """
    old_bvals, old_bvecs, new_bvals, new_bvecs = (
        np.asarray(arr, dtype=np.float64)
        for arr in (entry_bvals, entry_bvecs, gtab.bvals, gtab.bvecs)
    )
    same_shapes = (old_bvals.shape, old_bvecs.shape) == (
        new_bvals.shape,
        new_bvecs.shape,
    )
    if not same_shapes:
        return False
    bvals_ok = np.allclose(old_bvals, new_bvals, atol=bval_tol)
    return bvals_ok and np.allclose(old_bvecs, new_bvecs, atol=bvec_tol)
def _diffusivity_matches(entry_config, current_config):
    """Check whether two diffusivity configurations are equivalent.

    Values may be scalars, ranges (lists/tuples/arrays, e.g. a stored JSON
    list against a current tuple) or non-numeric values. Numeric values are
    compared element-wise with ``np.allclose`` after checking that their
    shapes agree. A scalar stored against a range, or the reverse, is a
    mismatch rather than an error. Values that cannot be converted to a
    float array fall back to exact equality.

    Parameters
    ----------
    entry_config : dict
        Stored diffusivity configuration (from the cache registry).
    current_config : dict
        Current diffusivity configuration.

    Returns
    -------
    match : bool
        True if both configurations have the same keys and every pair of
        values agrees.
    """
    if set(entry_config) != set(current_config):
        return False
    for key, stored in entry_config.items():
        current = current_config[key]
        try:
            stored_arr = np.asarray(stored, dtype=np.float64)
            current_arr = np.asarray(current, dtype=np.float64)
        except (TypeError, ValueError):
            # Non-numeric (e.g. strings) or ragged values: compare exactly.
            if stored != current:
                return False
            continue
        # Shape check first: a scalar vs a range must be a mismatch, not a
        # broadcasted comparison whose truth value is ambiguous.
        if stored_arr.shape != current_arr.shape:
            return False
        if not np.allclose(stored_arr, current_arr):
            return False
    return True
def _load_cache_registry(cache_dir):
    """Load the cache registry JSON from *cache_dir*.

    Parameters
    ----------
    cache_dir : Path or str
        Cache directory containing ``cache_registry.json``.

    Returns
    -------
    registry : list
        The stored registry entries. An empty list is returned when the file
        does not exist yet, or (with a warning) when its content is not
        valid JSON or not a JSON list. The latter can happen when the file
        was truncated by an interrupted or concurrent writer; treating it as
        empty lets callers regenerate instead of failing.
    """
    registry_path = Path(cache_dir) / "cache_registry.json"
    try:
        with open(registry_path, "r", encoding="utf-8") as f:
            registry = json.load(f)
    except FileNotFoundError:
        return []
    except json.JSONDecodeError as err:
        warnings.warn(
            f"Ignoring unreadable FORCE cache registry {registry_path}: {err}",
            UserWarning,
            stacklevel=2,
        )
        return []
    if not isinstance(registry, list):
        warnings.warn(
            f"Ignoring malformed FORCE cache registry {registry_path}: "
            "expected a JSON list.",
            UserWarning,
            stacklevel=2,
        )
        return []
    return registry
def _save_cache_registry(cache_dir, registry):
    """Persist *registry* as JSON in *cache_dir*, replacing it atomically.

    The JSON is first written to a temporary file in the same directory and
    then moved over ``cache_registry.json`` with ``os.replace``. Readers
    therefore never observe a truncated file, and a serialisation error
    leaves the previous registry intact. On Windows the replace can be
    refused while another process has the registry open; in that case the
    file is rewritten in place instead, as earlier versions did.

    Parameters
    ----------
    cache_dir : Path or str
        Cache directory.
    registry : list
        JSON-serialisable registry entries.

    Raises
    ------
    TypeError
        If *registry* is not JSON-serialisable (the old file is kept).
    """
    import tempfile

    cache_dir = Path(cache_dir)
    registry_path = cache_dir / "cache_registry.json"
    fd, tmp_name = tempfile.mkstemp(
        dir=cache_dir, prefix=".cache_registry_", suffix=".json.tmp"
    )
    try:
        with os.fdopen(fd, "w", encoding="utf-8") as f:
            json.dump(registry, f, indent=2)
        try:
            os.replace(tmp_name, registry_path)
        except PermissionError:
            # Windows: target held open by a reader -> non-atomic fallback.
            with open(registry_path, "w", encoding="utf-8") as f:
                json.dump(registry, f, indent=2)
    finally:
        # Remove the temporary file unless os.replace already consumed it.
        if os.path.exists(tmp_name):
            os.unlink(tmp_name)
def _locked_registry_update(cache_dir, update_fn):
    """Apply *update_fn* to the cache registry while holding an exclusive lock.

    The lock is taken on ``cache_registry.lock`` inside *cache_dir*, using
    ``msvcrt.locking`` on Windows and ``fcntl.flock`` elsewhere. It is held
    across the load -> update -> save sequence and released even if that
    sequence raises.

    Parameters
    ----------
    cache_dir : Path
        Cache directory.
    update_fn : callable
        Receives the current registry list and returns the list to store.
    """
    lock_path = cache_dir / "cache_registry.lock"
    with open(lock_path, "w") as lock_fh:
        if sys.platform == "win32":
            import msvcrt

            def acquire():
                msvcrt.locking(lock_fh.fileno(), msvcrt.LK_LOCK, 1)

            def release():
                # msvcrt locks a byte range starting at the current position,
                # so rewind to the byte that was locked before unlocking.
                lock_fh.seek(0)
                msvcrt.locking(lock_fh.fileno(), msvcrt.LK_UNLCK, 1)

        else:
            import fcntl

            def acquire():
                fcntl.flock(lock_fh, fcntl.LOCK_EX)

            def release():
                fcntl.flock(lock_fh, fcntl.LOCK_UN)

        acquire()
        try:
            current = _load_cache_registry(cache_dir)
            _save_cache_registry(cache_dir, update_fn(current))
        finally:
            release()
def _find_cached_simulation(cache_dir, gtab, diffusivity_config, num_simulations):
    """Search the registry for a simulation matching the given parameters.

    Entries are scanned in registry order. The first one whose simulation
    count, gradient table (within tolerance) and configuration match, and
    whose ``.npz`` file still exists, is returned. Malformed entries, such
    as missing keys or non-dict items written by another version or edited
    by hand, are skipped instead of aborting the lookup.

    Parameters
    ----------
    cache_dir : Path or str
        Cache directory.
    gtab : GradientTable
        Gradient table.
    diffusivity_config : dict
        Configuration used as the cache key (diffusivity ranges and any
        additional generation options stored alongside them).
    num_simulations : int
        Number of simulations requested.

    Returns
    -------
    path : str or None
        Path to the cached ``.npz`` file, or None if no match found.
    """
    cache_dir = Path(cache_dir)
    for entry in _load_cache_registry(cache_dir):
        try:
            if entry["num_simulations"] != num_simulations:
                continue
            if not _gtab_matches(entry["bvals"], entry["bvecs"], gtab):
                continue
            if not _diffusivity_matches(
                entry["diffusivity_config"], diffusivity_config
            ):
                continue
            filename = entry["filename"]
        except (KeyError, TypeError, ValueError):
            # Malformed registry entry: it cannot be a valid match.
            continue
        candidate = cache_dir / filename
        if candidate.exists():
            return str(candidate)
    return None
def _register_cached_simulation(
    cache_dir, gtab, diffusivity_config, num_simulations, filename
):
    """Add a new entry to the cache registry.

    Parameters
    ----------
    cache_dir : Path
        Cache directory.
    gtab : GradientTable
        Gradient table (its bvals/bvecs are stored as float lists).
    diffusivity_config : dict
        Configuration used as the cache key. Its values may contain numpy
        arrays/scalars, possibly nested in lists or tuples; they are
        converted to plain JSON types.
    num_simulations : int
        Number of simulations.
    filename : str
        Filename of the saved ``.npz`` inside *cache_dir*.
    """

    def _to_json(obj):
        # Recursively convert numpy containers/scalars to plain Python so
        # json.dump never meets e.g. np.float32 or np.bool_ nested in a tuple.
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        if isinstance(obj, np.generic):
            return obj.item()
        if isinstance(obj, (list, tuple)):
            return [_to_json(item) for item in obj]
        if isinstance(obj, dict):
            return {str(k): _to_json(v) for k, v in obj.items()}
        return obj

    entry = {
        "bvals": np.asarray(gtab.bvals, dtype=np.float64).tolist(),
        "bvecs": np.asarray(gtab.bvecs, dtype=np.float64).tolist(),
        "diffusivity_config": {
            k: _to_json(v) for k, v in diffusivity_config.items()
        },
        "num_simulations": int(num_simulations),
        "filename": filename,
    }

    def _append(registry):
        registry.append(entry)
        return registry

    _locked_registry_update(cache_dir, _append)
[docs]
class SignalIndex:
    """Exhaustive inner-product similarity index.

    Library vectors are stored as a float32, C-contiguous matrix. Queries
    are answered by the compiled ``search_inner_product`` routine, which the
    module describes as BLAS matrix multiplication plus a streaming top-k
    heap.

    Parameters
    ----------
    d : int
        Dimension of the indexed vectors; must be positive.
    """

    def __init__(self, d):
        if d <= 0:
            raise ValueError(f"Dimension must be positive, got {d}")
        self.d = int(d)
        self.ntotal = 0
        self._xb = None

    def _as_matrix(self, x, label):
        """Coerce *x* to a float32 C-contiguous ``(n, d)`` matrix.

        *label* ("Vector" or "Query") prefixes the dimension-mismatch error.
        """
        arr = np.ascontiguousarray(x, dtype=np.float32)
        if arr.ndim == 1:
            if len(arr) != self.d:
                raise ValueError(
                    f"{label} dimension {len(arr)} != index dimension {self.d}"
                )
            arr = arr.reshape(1, -1)
        if arr.ndim != 2:
            raise ValueError(f"Expected 1D or 2D array, got {arr.ndim}D")
        if arr.shape[1] != self.d:
            raise ValueError(
                f"{label} dimension {arr.shape[1]} != index dimension {self.d}"
            )
        return arr

    def add(self, x):
        """Append vectors to the index.

        Parameters
        ----------
        x : array-like (n, d) or (d,)
            Vectors to store; converted to float32 C-contiguous.

        Notes
        -----
        Every call rebuilds the stored matrix with ``np.vstack``. The method
        is meant for one bulk load; many small calls cost O(n²) copying.
        """
        rows = self._as_matrix(x, "Vector")
        if self._xb is not None:
            self._xb = np.vstack([self._xb, rows])
        else:
            self._xb = rows.copy()
        self.ntotal = len(self._xb)

    def search(self, x, k):
        """Return the *k* stored vectors with the largest inner product.

        Parameters
        ----------
        x : array-like (n, d) or (d,)
            Query vectors.
        k : int
            Number of neighbours; values above ``ntotal`` are clamped with a
            ``UserWarning``.

        Returns
        -------
        distances : ndarray (n, k)
            Inner products, in descending order.
        indices : ndarray (n, k)
            Row indices of the matching stored vectors.
        """
        if self.ntotal == 0:
            raise RuntimeError("Cannot search empty index")
        queries = self._as_matrix(x, "Query")
        if k <= 0:
            raise ValueError(f"k must be positive, got {k}")
        if k > self.ntotal:
            warnings.warn(
                f"k={k} exceeds index size ({self.ntotal}); "
                f"clamping to {self.ntotal}.",
                UserWarning,
                stacklevel=2,
            )
            k = self.ntotal
        distances, indices = _cython_search(queries, self._xb, k)
        return distances, indices

    def reset(self):
        """Drop every stored vector."""
        self._xb = None
        self.ntotal = 0

    def __repr__(self):
        return f"SignalIndex(d={self.d}, ntotal={self.ntotal})"
[docs]
def normalize_signals(signals):
    """Scale every row of *signals* to unit Euclidean length.

    Rows whose norm is zero are left unchanged (divided by 1) so that no
    NaNs are produced.

    Parameters
    ----------
    signals : array-like (N, M)
        N signals with M measurements each.

    Returns
    -------
    normalized : ndarray (N, M)
        float32, C-contiguous array of unit-norm rows.
    """
    arr = np.asarray(signals, dtype=np.float32)
    row_lengths = np.linalg.norm(arr, axis=1, keepdims=True)
    row_lengths[row_lengths == 0] = 1.0
    scaled = np.divide(arr, row_lengths)
    return np.ascontiguousarray(scaled)
[docs]
def create_signal_index(signals_norm):
    """Build a ``SignalIndex`` pre-loaded with the library signals.

    Because the stored rows are expected to be L2-normalised, inner-product
    search on the returned index amounts to cosine similarity.

    Parameters
    ----------
    signals_norm : ndarray (N, M)
        L2-normalised library signals; M sets the index dimension.

    Returns
    -------
    index : SignalIndex
        Index containing all N signals.
    """
    n_measurements = signals_norm.shape[1]
    index = SignalIndex(n_measurements)
    index.add(signals_norm)
    return index
[docs]
def softmax_stable(x, *, axis=1):
    """Softmax along *axis*, shifted by the maximum to avoid overflow.

    Parameters
    ----------
    x : ndarray
        Input scores.
    axis : int, optional
        Axis along which the probabilities sum to one.

    Returns
    -------
    softmax : ndarray
        Softmax probabilities with the same shape as *x*.
    """
    peak = np.max(x, axis=axis, keepdims=True)
    unnormalized = np.exp(x - peak)
    return unnormalized / unnormalized.sum(axis=axis, keepdims=True)
[docs]
def compute_uncertainty_ambiguity(scores):
    """Summarise the spread of per-voxel neighbour scores.

    Parameters
    ----------
    scores : ndarray (N, K)
        Penalised match scores of the K neighbours of each voxel.

    Returns
    -------
    uncertainty : ndarray (N,)
        Interquartile range (75th minus 25th percentile) of the scores,
        float32.
    ambiguity : ndarray (N,)
        Fraction of neighbours scoring strictly above the midpoint between
        the minimum and maximum score, float32.
    """
    upper_q = np.percentile(scores, 75, axis=1)
    lower_q = np.percentile(scores, 25, axis=1)
    uncertainty = (upper_q - lower_q).astype(np.float32)

    midpoint = (np.max(scores, axis=1) + np.min(scores, axis=1)) * 0.5
    above_mid = scores > midpoint[:, None]
    ambiguity = (above_mid.sum(axis=1) / scores.shape[1]).astype(np.float32)
    return uncertainty, ambiguity
# Scalar maps for which per-parameter posterior uncertainty and ambiguity
# are computed, whenever the parameter is present in the simulation library.
MICRO_PARAMS = (
    # tensor-derived metrics
    "fa", "md", "rd",
    # tissue fractions
    "wm_fraction", "gm_fraction", "csf_fraction",
    # fiber configuration and microstructure
    "num_fibers", "dispersion", "nd", "ufa_voxel",
    # kurtosis-derived metrics
    "ak", "rk", "mk", "kfa",
)
def _weighted_percentile(vals, weights, q):
    """Batch weighted percentile.

    For each row, values are sorted and the first value whose normalised
    cumulative weight reaches *q* is returned.

    Parameters
    ----------
    vals : ndarray (N, K)
        Parameter values across K neighbors per voxel.
    weights : ndarray (N, K)
        Posterior weights (non-negative).
    q : float
        Quantile in [0, 1].

    Returns
    -------
    result : ndarray (N,)
        Weighted percentile for each voxel.
    """
    order = np.argsort(vals, axis=1)
    sorted_vals = np.take_along_axis(vals, order, axis=1)
    sorted_w = np.take_along_axis(weights, order, axis=1)
    cdf = np.cumsum(sorted_w, axis=1)
    cdf = cdf / (cdf[:, -1:] + EPSILON)
    reached = cdf >= q
    # EPSILON keeps the last CDF value slightly below 1, and all-zero
    # weights give an all-zero CDF. In either case no entry may reach q,
    # where np.argmax would silently return 0 (the minimum); use the last
    # (largest) value instead.
    j = np.where(
        reached.any(axis=1), np.argmax(reached, axis=1), sorted_vals.shape[1] - 1
    )
    return sorted_vals[np.arange(sorted_vals.shape[0]), j]
def _fwhm_kde_batch(vals, weights, *, n_grid=100, batch_size=2000):
    """Batch FWHM via weighted KDE.

    Estimates the full width at half maximum of the weighted posterior
    density for each voxel using a Gaussian kernel density estimate.

    The width is taken between the first and the last grid point whose
    density is at least half of the peak. For a multi-modal density it
    therefore also spans the gap between modes.

    Parameters
    ----------
    vals : ndarray (N, K)
        Parameter values across K neighbors.
    weights : ndarray (N, K)
        Posterior weights.
    n_grid : int, optional
        Number of grid points for KDE evaluation.
    batch_size : int, optional
        Internal batch size to limit memory usage.

    Returns
    -------
    fwhm : ndarray (N,)
        FWHM in parameter units (float32).
    """
    N, K = vals.shape
    fwhm = np.empty(N, dtype=np.float32)
    # Voxels are processed in chunks because the kernel tensor built below
    # has shape (batch, n_grid, K); batch_size bounds its size.
    for start in range(0, N, batch_size):
        end = min(start + batch_size, N)
        vals_batch = vals[start:end].astype(np.float32)
        weights_batch = weights[start:end].astype(np.float32)
        n = end - start
        # Evaluation grid: the observed value range padded by 15% on each
        # side. EPSILON keeps the grid non-degenerate when all values coincide.
        lo = vals_batch.min(axis=1)
        hi = vals_batch.max(axis=1)
        span = hi - lo
        pad = 0.15 * (span + EPSILON)
        lo_p, hi_p = lo - pad, hi + pad
        span_p = hi_p - lo_p
        t = np.linspace(0, 1, n_grid, dtype=np.float32)
        grid = lo_p[:, None] + t[None, :] * span_p[:, None]
        # Per-voxel bandwidth: padded span divided by max(sqrt(K), 4).
        bw = span_p / max(K**0.5, 4.0)
        bw = np.maximum(bw, EPSILON)
        # Weighted sum of unnormalised Gaussian kernels at each grid point.
        diff = (grid[:, :, None] - vals_batch[:, None, :]) / bw[:, None, None]
        density = np.sum(np.exp(-0.5 * diff**2) * weights_batch[:, None, :], axis=2)
        peak = density.max(axis=1)
        half_peak = 0.5 * peak
        above = density >= half_peak[:, None]
        any_above = np.any(above, axis=1)
        first = np.argmax(above, axis=1)
        # argmax over the reversed mask locates the last above-half point.
        last = n_grid - 1 - np.argmax(above[:, ::-1], axis=1)
        f = grid[np.arange(n), last] - grid[np.arange(n), first]
        # Rows with no point at/above half max (only possible for non-finite
        # densities) get a width of 0.
        f[~any_above] = 0.0
        fwhm[start:end] = f
    return fwhm
[docs]
def compute_microstructure_uncertainty_ambiguity(vals, weights, prior_range):
    """Per-parameter spread of the weighted posterior, relative to the prior.

    Uncertainty is the weighted interquartile range of the neighbour values.
    Ambiguity is the full width at half maximum of their weighted Gaussian
    KDE. Both are divided by the parameter's range across the simulation
    library, so 0 means no spread. Uncertainty is at most 1 when *vals* come
    from that library. Ambiguity can slightly exceed 1 because the FWHM is
    measured on a smoothed, padded KDE grid.

    Parameters
    ----------
    vals : ndarray (N, K)
        Parameter values of the K neighbours of each voxel.
    weights : ndarray (N, K)
        Posterior (softmax) weights of those neighbours.
    prior_range : float
        Maximum minus minimum of the parameter over the simulation library.

    Returns
    -------
    uncertainty : ndarray (N,)
        Weighted IQR divided by *prior_range* (float32).
    ambiguity : ndarray (N,)
        FWHM divided by *prior_range* (float32).
    """
    scale = prior_range + EPSILON
    upper = _weighted_percentile(vals, weights, 0.75)
    lower = _weighted_percentile(vals, weights, 0.25)
    width = _fwhm_kde_batch(vals, weights)
    uncertainty = ((upper - lower) / scale).astype(np.float32)
    ambiguity = (width / scale).astype(np.float32)
    return uncertainty, ambiguity
[docs]
def postprocess_peaks(preds, target_sphere, fracs):
    """Convert per-vertex peak predictions into fixed-size peak arrays.

    For each voxel, vertices whose prediction equals 1 are taken as peak
    directions, in vertex order, keeping at most 5. Unused slots hold zero
    directions and index -1. The first (up to) 5 entries of the voxel's
    *fracs* row are copied as peak values, independently of how many peaks
    were found.

    Parameters
    ----------
    preds : ndarray (..., n_vertices)
        Per-vertex predictions; entries equal to 1 mark peaks.
    target_sphere : Sphere
        Sphere whose ``vertices`` (n_vertices, 3) give the peak directions.
    fracs : ndarray (..., n_fracs)
        Per-voxel fractions. Presumably laid out with the same voxels as
        *preds*; verify against callers.

    Returns
    -------
    peaks_output : ndarray (..., 5, 3)
        Peak directions, float32.
    peak_indices : ndarray (..., 5)
        Vertex indices of the peaks, int32, -1 where unused.
    peak_vals : ndarray (..., 5)
        Fraction values, float32.
    """
    max_peaks = 5
    spatial_shape = preds.shape[:-1]
    flat_preds = preds.reshape(-1, preds.shape[-1])
    flat_fracs = fracs.reshape(-1, fracs.shape[-1])
    n_vox = flat_preds.shape[0]
    verts = target_sphere.vertices

    dirs = np.zeros((n_vox, max_peaks, 3), dtype=np.float32)
    idxs = np.full((n_vox, max_peaks), -1, dtype=np.int32)
    vals = np.zeros((n_vox, max_peaks), dtype=np.float32)

    for vox in range(n_vox):
        is_peak = flat_preds[vox] == 1
        peak_dirs = verts[is_peak][:max_peaks]
        peak_ids = np.flatnonzero(is_peak)[:max_peaks]
        count = len(peak_dirs)
        if count:
            dirs[vox, :count] = peak_dirs
            idxs[vox, :count] = peak_ids
        frac_row = flat_fracs[vox]
        n_frac = min(max_peaks, frac_row.shape[0])
        vals[vox, :n_frac] = frac_row[:n_frac]

    return (
        dirs.reshape((*spatial_shape, max_peaks, 3)),
        idxs.reshape((*spatial_shape, max_peaks)),
        vals.reshape((*spatial_shape, max_peaks)),
    )
[docs]
class FORCEModel(ReconstModel):
    """FORCE reconstruction model for signal matching based microstructure."""

    def __init__(
        self,
        gtab,
        *,
        simulations=None,
        penalty=1e-5,
        n_neighbors=50,
        use_posterior=False,
        posterior_beta=2000.0,
        compute_odf=False,
        verbose=False,
    ):
        r"""
        FORCE (FORward modeling for Complex microstructure Estimation) model
        :footcite:p:`Shah2025`.

        FORCE is a forward modeling paradigm that reframes how diffusion data
        is analyzed. Instead of inverting the measured signal, FORCE simulates
        a large set of biologically plausible intra-voxel fiber configurations
        and tissue compositions. It then identifies the best-matching simulation
        for each voxel by operating directly in the signal space.

        Parameters
        ----------
        gtab : GradientTable
            Gradient table.
        simulations : dict or None, optional
            Pre-computed FORCE simulations with signals and parameters.
            If None, call generate() to create simulations.
        penalty : float, optional
            Penalty weight for fiber complexity.
        n_neighbors : int, optional
            Number of neighbors for matching.
        use_posterior : bool, optional
            Use posterior averaging instead of best match.
        posterior_beta : float, optional
            Softmax temperature for posterior.
        compute_odf : bool, optional
            Compute posterior ODF maps.
        verbose : bool, optional
            Show progress bar and status messages.

        Notes
        -----
        The fit method uses the @multi_voxel_fit decorator which supports
        parallel execution. Pass ``engine`` and ``n_jobs`` keyword arguments
        to the fit method. Available engines: "serial", "ray", "joblib",
        "dask".

        References
        ----------
        .. footbibliography::
        """
        self.gtab = gtab
        self.simulations = simulations
        self.penalty = penalty
        self.n_neighbors = n_neighbors
        self.use_posterior = use_posterior
        self.posterior_beta = posterior_beta
        self.compute_odf = compute_odf
        self.verbose = verbose
        self._index = None
        self._penalty_array = None
        # Filled by _prepare_library(); initialised here so the attribute
        # exists even before any simulations are loaded.
        self._prior_ranges = {}
        if simulations is not None:
            self._prepare_library()

    def generate(
        self,
        *,
        num_simulations=500000,
        output_path=None,
        num_cpus=1,
        wm_threshold=0.5,
        tortuosity=False,
        odi_range=(0.01, 0.3),
        diffusivity_config=None,
        compute_dti=True,
        compute_dki=False,
        verbose=False,
        use_cache=True,
    ):
        """Generate simulations for matching.

        When ``output_path`` is ``None`` and ``use_cache`` is ``True``,
        simulations are cached in ``~/.dipy/force_simulations/`` (or
        ``$DIPY_HOME/force_simulations/``). A registry file
        (``cache_registry.json``) records, for each cached file:

        - the bvals and bvecs;
        - the diffusivity configuration;
        - the generation options that change the library content
          (``wm_threshold``, ``tortuosity``, ``odi_range``, ``compute_dti``
          and the effective ``compute_dki``);
        - the number of simulations.

        If a cached simulation matching all of these already exists (the
        gradient table within tolerance), it is loaded from disk and
        generation is skipped. Cache entries written without the generation
        options never match and are simply regenerated.

        Set ``use_cache=False`` to force regeneration even when a matching
        cached simulation exists.

        Parameters
        ----------
        num_simulations : int, optional
            Number of simulated voxels.
        output_path : str, optional
            Path to save simulations (.npz). When None, saves to
            ``~/.dipy/force_simulations/`` and uses caching.
        num_cpus : int, optional
            Number of CPU cores for parallel processing.
        wm_threshold : float, optional
            Minimum WM fraction to include fiber labels.
        tortuosity : bool, optional
            Use tortuosity constraint for perpendicular diffusivity.
        odi_range : tuple, optional
            (min, max) orientation dispersion index range.
        diffusivity_config : dict, optional
            Custom diffusivity ranges.
        compute_dti : bool, optional
            Compute DTI metrics (FA, MD, RD).
        compute_dki : bool, optional
            Compute DKI metrics (AK, RK, MK, KFA). Disabled with a warning
            when fewer than two non-zero b-value shells are present.
        verbose : bool, optional
            Enable progress output.
        use_cache : bool, optional
            Whether to use cached simulations when ``output_path`` is
            None. Set to ``False`` to always regenerate.

        Returns
        -------
        self : FORCEModel
            Model with simulations loaded.
        """
        import uuid

        b0_thr = getattr(self.gtab, "b0_threshold", 50)
        unique_bvals = unique_bvals_tolerance(self.gtab.bvals, tol=50)
        n_nonzero_shells = int(np.sum(unique_bvals > b0_thr))
        if compute_dki and n_nonzero_shells < 2:
            warnings.warn(
                f"compute_dki=True but only {n_nonzero_shells} non-zero "
                "b-value shell found (need at least 2). "
                "Disabling DKI computation.",
                UserWarning,
                stacklevel=2,
            )
            compute_dki = False
        elif not compute_dki and n_nonzero_shells >= 2:
            logger.info(
                f"Found {n_nonzero_shells} non-zero b-value shells. "
                "You can compute DKI metrics (AK, RK, MK, KFA) by "
                "setting compute_dki=True."
            )

        # Resolve the diffusivity config that will actually be used
        resolved_config = (
            diffusivity_config
            if diffusivity_config is not None
            else get_default_diffusivity_config()
        )

        # Every option that changes the generated library must be part of the
        # cache key. Otherwise a library generated e.g. without DKI metrics,
        # or with another ODI range, would be served for a different request.
        # The underscore prefix keeps these keys apart from the diffusivity
        # range names; values are plain floats/lists so they survive the JSON
        # registry and compare with np.isclose.
        cache_key = dict(resolved_config)
        cache_key.update(
            {
                "_wm_threshold": float(wm_threshold),
                "_tortuosity": float(bool(tortuosity)),
                "_odi_range": (
                    [] if odi_range is None else [float(v) for v in odi_range]
                ),
                "_compute_dti": float(bool(compute_dti)),
                "_compute_dki": float(bool(compute_dki)),
            }
        )

        # --- Cache logic when no explicit output_path is given ----------
        if output_path is None and use_cache:
            cache_dir = _get_force_cache_dir()
            cached = _find_cached_simulation(
                cache_dir,
                self.gtab,
                cache_key,
                num_simulations,
            )
            if cached is not None:
                if verbose:
                    print(f"[FORCE] Loading cached simulations from {cached}")
                self.simulations = load_force_simulations(cached)
                self._prepare_library()
                return self

        # --- Generate new simulations -----------------------------------
        self.simulations = generate_force_simulations(
            self.gtab,
            num_simulations=num_simulations,
            num_cpus=num_cpus,
            wm_threshold=wm_threshold,
            tortuosity=tortuosity,
            odi_range=odi_range,
            diffusivity_config=diffusivity_config,
            compute_dti=compute_dti,
            compute_dki=compute_dki,
            verbose=verbose,
        )

        if output_path is not None:
            save_force_simulations(self.simulations, output_path)
        else:
            # Save into the .dipy cache and register. A random filename
            # cannot collide between concurrent processes, unlike a name
            # derived from the registry length read in a separate lock
            # section. The file is registered only after it is fully written.
            cache_dir = _get_force_cache_dir()
            filename = f"force_sim_{uuid.uuid4().hex}.npz"
            save_force_simulations(self.simulations, str(cache_dir / filename))
            _register_cached_simulation(
                cache_dir,
                self.gtab,
                cache_key,
                num_simulations,
                filename,
            )
            if verbose:
                print(f"[FORCE] Cached simulations to {cache_dir / filename}")

        self._prepare_library()
        return self

    def load(self, input_path):
        """Load pre-computed simulations from file.

        Parameters
        ----------
        input_path : str
            Path to simulations file (.npz).

        Returns
        -------
        self : FORCEModel
            Model with simulations loaded.
        """
        self.simulations = load_force_simulations(input_path)
        self._prepare_library()
        return self

    def _prepare_library(self):
        """Prepare library for matching.

        Builds the search index from the L2-normalised library signals, the
        per-entry penalty ``penalty * num_fibers`` and the prior range
        (max - min) of every MICRO_PARAMS entry present in the library.
        """
        signals = self.simulations["signals"]

        # Normalize library signals (zero-norm rows are left unscaled)
        lib_norm = np.linalg.norm(signals, axis=1, keepdims=True)
        lib_norm[lib_norm == 0] = 1.0
        signals_norm = np.ascontiguousarray((signals / lib_norm).astype(np.float32))

        # Build index
        self._index = create_signal_index(signals_norm)

        # Penalty array (zero penalty when the library has no num_fibers)
        num_fibers = self.simulations.get(
            "num_fibers", np.zeros(len(signals), dtype=np.float32)
        )
        self._penalty_array = (self.penalty * num_fibers).astype(np.float32)

        # Prior ranges for microstructure uncertainty normalization
        d = self.simulations
        self._prior_ranges = {}
        for param in MICRO_PARAMS:
            if param in d:
                arr = d[param]
                self._prior_ranges[param] = float(arr.max() - arr.min())

    @staticmethod
    def _fetch_params_batched(lib_idx, d):
        """Vectorised parameter look-up for best-match indices.

        Parameters
        ----------
        lib_idx : ndarray (N,)
            Library indices of the best match per voxel.
        d : dict
            Simulation dictionary.

        Returns
        -------
        params : dict of ndarray
            Parameters of the matched entries. ``ufa_*`` and DKI keys are
            present only when the library contains them; ``odf`` is None when
            the library has no ``odfs``.
        """
        params = {
            "fa": d["fa"][lib_idx].astype(np.float32),
            "md": d["md"][lib_idx].astype(np.float32),
            "rd": d["rd"][lib_idx].astype(np.float32),
            "wm_fraction": d["wm_fraction"][lib_idx].astype(np.float32),
            "gm_fraction": d["gm_fraction"][lib_idx].astype(np.float32),
            "csf_fraction": d["csf_fraction"][lib_idx].astype(np.float32),
            "num_fibers": d["num_fibers"][lib_idx].astype(np.float32),
            "dispersion": d["dispersion"][lib_idx].astype(np.float32),
            "nd": d["nd"][lib_idx].astype(np.float32),
            "labels": d["labels"][lib_idx].astype(np.int8),
            "fracs": d["fraction_array"][lib_idx].astype(np.float32),
        }
        if "ufa_wm" in d:
            params["ufa_wm"] = d["ufa_wm"][lib_idx].astype(np.float32)
            params["ufa_voxel"] = d["ufa_voxel"][lib_idx].astype(np.float32)
        if "ak" in d:
            params["ak"] = d["ak"][lib_idx].astype(np.float32)
            params["rk"] = d["rk"][lib_idx].astype(np.float32)
            params["mk"] = d["mk"][lib_idx].astype(np.float32)
            params["kfa"] = d["kfa"][lib_idx].astype(np.float32)
        params["odf"] = d["odfs"][lib_idx].astype(np.float32) if "odfs" in d else None
        params["predicted_signal"] = d["signals"][lib_idx].astype(np.float32)
        return params

    @staticmethod
    def _posterior_params_batched(neighbors, W, d, lib_idx):
        """Vectorised posterior-averaging over neighbours.

        Parameters
        ----------
        neighbors : ndarray (N, K)
            Neighbour indices.
        W : ndarray (N, K)
            Posterior weights.
        d : dict
            Simulation dictionary.
        lib_idx : ndarray (N,)
            Library index of the best match per voxel. ``labels`` and
            ``fracs`` are taken from this entry rather than averaged.

        Returns
        -------
        params : dict of ndarray
        """

        def _wavg(field):
            # Posterior-weighted mean of a scalar field over the K neighbours
            return np.sum(W * d[field][neighbors], axis=1).astype(np.float32)

        params = {
            "fa": _wavg("fa"),
            "md": _wavg("md"),
            "rd": _wavg("rd"),
            "wm_fraction": _wavg("wm_fraction"),
            "gm_fraction": _wavg("gm_fraction"),
            "csf_fraction": _wavg("csf_fraction"),
            "num_fibers": _wavg("num_fibers"),
            "dispersion": _wavg("dispersion"),
            "nd": _wavg("nd"),
            "labels": d["labels"][lib_idx].astype(np.int8),
            "fracs": d["fraction_array"][lib_idx].astype(np.float32),
        }
        if "ufa_wm" in d:
            params["ufa_wm"] = _wavg("ufa_wm")
            params["ufa_voxel"] = _wavg("ufa_voxel")
        if "ak" in d:
            params["ak"] = _wavg("ak")
            params["rk"] = _wavg("rk")
            params["mk"] = _wavg("mk")
            params["kfa"] = _wavg("kfa")

        # Posterior ODF: weighted sum of max-normalised neighbour ODFs,
        # itself max-normalised.
        if "odfs" in d:
            K = neighbors.shape[1]
            odf = np.zeros((neighbors.shape[0], d["odfs"].shape[1]), dtype=np.float32)
            for kk in range(K):
                odf_k = d["odfs"][neighbors[:, kk]].astype(np.float32)
                odf_k /= np.max(odf_k, axis=1, keepdims=True) + EPSILON
                odf += W[:, kk : kk + 1] * odf_k
            odf /= np.max(odf, axis=1, keepdims=True) + EPSILON
            params["odf"] = odf
        else:
            params["odf"] = None

        # Posterior mean signal
        params["predicted_signal"] = posterior_mean_signal(d["signals"], W, neighbors)
        return params

    @multi_voxel_fit(
        batched=True,
        shared_obj=("_penalty_array", "_index", "simulations"),
        chunk_size={"serial": 10_000, "ray": "auto"},
    )
    def fit(self, data, *, mask=None, **kwargs):
        """Fit model to data.

        Parameters
        ----------
        data : ndarray
            Diffusion data for a single voxel (1D) or multiple voxels (ND).
        mask : ndarray, optional
            Brain mask (for multi-voxel data).
        **kwargs : dict
            Additional arguments passed to multi_voxel_fit decorator:

            - engine : str, optional
                Parallel engine: "serial", "ray", "joblib", "dask".
            - n_jobs : int, optional
                Number of parallel jobs.
            - verbose : bool, optional
                Show progress bar.

        Returns
        -------
        fit : FORCEFit or ndarray of FORCEFit
            Fitted model for a single voxel (1-D input) or an object array
            of fitted models for a batch of voxels (2-D input).

        Notes
        -----
        This method is decorated with @multi_voxel_fit(batched=True)
        which handles multi-voxel dispatch, mask application, and aggregation
        into a MultiVoxelFit. The method itself handles both 1-D
        (single voxel) and 2-D (batch) inputs directly.

        For parallel execution, use engine="ray", n_jobs=4 arguments in model
        fit() call.

        **Memory warning (joblib / dask engines):** When engine="joblib"
        or engine="dask", the full simulation library (including the
        signal matrix and search index, ~120-400 MB for 100k simulations) is
        pickled and sent to *every* worker chunk. With 8 workers this can
        consume several gigabytes of RAM. For num_simulations > ~10 000
        use engine="ray" instead, which places the library in a shared
        object store and avoids redundant copies across workers.
        """
        if self.simulations is None:
            raise RuntimeError(
                "No simulations loaded. Call generate() or provide simulations."
            )
        if self._index is None:
            raise RuntimeError(
                "Signal index is not prepared. Call _prepare_library() or "
                "reload simulations via generate() or load()."
            )

        single = data.ndim == 1
        data2d = data.reshape(1, -1) if single else data
        data2d = np.ascontiguousarray(data2d, dtype=np.float32)

        # L2-normalise queries so the inner product is a cosine similarity
        norms = np.linalg.norm(data2d, axis=1, keepdims=True).astype(np.float32)
        norms[norms == 0] = 1.0
        query_norm = np.ascontiguousarray(data2d / norms)

        D, neighbors = self._index.search(query_norm, k=self.n_neighbors)
        # Penalised score: similarity minus the fiber-complexity penalty
        S = D - self._penalty_array[neighbors]
        U, A = compute_uncertainty_ambiguity(S)

        d = self.simulations
        n_vox = data2d.shape[0]
        best = np.argmax(S, axis=1)
        lib_idx = neighbors[np.arange(n_vox), best]
        W = softmax_stable(self.posterior_beta * S, axis=1)

        if self.use_posterior:
            entropy = -np.sum(W * np.log(W + EPSILON), axis=1)
            params_arrays = self._posterior_params_batched(neighbors, W, d, lib_idx)
            params_arrays["uncertainty"] = U
            params_arrays["ambiguity"] = A
            params_arrays["entropy"] = entropy.astype(np.float32)
        else:
            params_arrays = self._fetch_params_batched(lib_idx, d)
            params_arrays["uncertainty"] = U
            params_arrays["ambiguity"] = A
            params_arrays["entropy"] = np.full(n_vox, np.nan, dtype=np.float32)

        # Per-microstructure uncertainty and ambiguity from posterior density.
        # getattr keeps instances built without __init__ (e.g. unpickled
        # from older versions) working.
        prior_ranges = getattr(self, "_prior_ranges", {})
        for param in MICRO_PARAMS:
            if param in d and param in prior_ranges:
                vals = d[param][neighbors].astype(np.float32)
                u, a = compute_microstructure_uncertainty_ambiguity(
                    vals, W, prior_ranges[param]
                )
                params_arrays[f"uncertainty_{param}"] = u
                params_arrays[f"ambiguity_{param}"] = a

        if kwargs.pop("_raw", False):
            return params_arrays

        # Split the batched arrays into one FORCEFit per voxel; numpy scalars
        # become Python floats and a NaN entropy becomes None.
        fits = np.empty(n_vox, dtype=object)
        keys = list(params_arrays.keys())
        for i in range(n_vox):
            p = {}
            for k in keys:
                val = params_arrays[k]
                if val is None:
                    p[k] = None
                else:
                    v = val[i]
                    if isinstance(v, np.ndarray) and v.ndim == 0:
                        p[k] = float(v)
                    elif isinstance(v, (np.floating, np.integer)):
                        p[k] = float(v)
                    else:
                        p[k] = v
            entropy_val = p.get("entropy", 0.0)
            if entropy_val is not None and np.isnan(entropy_val):
                p["entropy"] = None
            fits[i] = FORCEFit(None, p)

        return fits[0] if single else fits
[docs]
class FORCEFit(ReconstFit):
"""FORCE model fit results for a single voxel."""
def __init__(self, model, params):
"""Initialize a FORCEFit class instance."""
if (
"entropy" in params
and params["entropy"] is not None
and np.isnan(params["entropy"])
):
params["entropy"] = None
self.model = model
self._params = params
@property
def fa(self):
"""Fractional anisotropy."""
return self._params["fa"]
@property
def md(self):
"""Mean diffusivity."""
return self._params["md"]
@property
def rd(self):
"""Radial diffusivity."""
return self._params["rd"]
@property
def wm_fraction(self):
"""White matter fraction."""
return self._params["wm_fraction"]
@property
def gm_fraction(self):
"""Gray matter fraction."""
return self._params["gm_fraction"]
@property
def csf_fraction(self):
"""CSF fraction."""
return self._params["csf_fraction"]
@property
def num_fibers(self):
"""Number of fibers."""
return self._params["num_fibers"]
@property
def dispersion(self):
"""Orientation dispersion."""
return self._params["dispersion"]
@property
def nd(self):
"""Neurite density."""
return self._params["nd"]
@property
def ufa_wm(self):
"""microFA in white matter."""
return self._params.get("ufa_wm", None)
@property
def ufa_voxel(self):
"""Voxel-averaged microFA."""
return self._params.get("ufa_voxel", None)
@property
def ak(self):
"""Axial kurtosis (DKI)."""
return self._params.get("ak", None)
@property
def rk(self):
"""Radial kurtosis (DKI)."""
return self._params.get("rk", None)
@property
def mk(self):
"""Mean kurtosis (DKI)."""
return self._params.get("mk", None)
@property
def kfa(self):
"""Kurtosis fractional anisotropy (DKI)."""
return self._params.get("kfa", None)
@property
def odf(self):
"""Orientation distribution function."""
return self._params.get("odf", None)
@property
def predicted_signal(self):
"""Predicted signal from matched library entry (cleaned DWI)."""
return self._params.get("predicted_signal", None)
@property
def uncertainty(self):
"""Uncertainty (IQR of penalized scores)."""
return self._params["uncertainty"]
@property
def ambiguity(self):
"""Ambiguity (fraction above half-max)."""
return self._params["ambiguity"]
@property
def entropy(self):
"""Entropy (posterior mode only)."""
return self._params.get("entropy", None)
@property
def label(self):
"""Fiber configuration label."""
return self._params.get("labels", None)
@property
def fracs(self):
"""Fiber fractions."""
return self._params.get("fracs", None)
@property
def uncertainty_fa(self):
"""Per-microstructure uncertainty for FA (fraction of prior range)."""
return self._params.get("uncertainty_fa", None)
@property
def ambiguity_fa(self):
"""Per-microstructure ambiguity for FA (fraction of prior range)."""
return self._params.get("ambiguity_fa", None)
@property
def uncertainty_md(self):
"""Per-microstructure uncertainty for MD (fraction of prior range)."""
return self._params.get("uncertainty_md", None)
@property
def ambiguity_md(self):
"""Per-microstructure ambiguity for MD (fraction of prior range)."""
return self._params.get("ambiguity_md", None)
@property
def uncertainty_rd(self):
    """Per-microstructure uncertainty for RD.

    Expressed as a fraction of the prior range; ``None`` when the fit
    parameters hold no ``"uncertainty_rd"`` entry.
    """
    return self._params.get("uncertainty_rd")
@property
def ambiguity_rd(self):
    """Per-microstructure ambiguity for RD.

    Expressed as a fraction of the prior range; ``None`` when the fit
    parameters hold no ``"ambiguity_rd"`` entry.
    """
    return self._params.get("ambiguity_rd")
@property
def uncertainty_wm_fraction(self):
    """Per-microstructure uncertainty for WM fraction.

    Expressed as a fraction of the prior range; ``None`` when the fit
    parameters hold no ``"uncertainty_wm_fraction"`` entry.
    """
    return self._params.get("uncertainty_wm_fraction")
@property
def ambiguity_wm_fraction(self):
    """Per-microstructure ambiguity for WM fraction.

    Expressed as a fraction of the prior range; ``None`` when the fit
    parameters hold no ``"ambiguity_wm_fraction"`` entry.
    """
    return self._params.get("ambiguity_wm_fraction")
@property
def uncertainty_gm_fraction(self):
    """Per-microstructure uncertainty for GM fraction.

    Expressed as a fraction of the prior range; ``None`` when the fit
    parameters hold no ``"uncertainty_gm_fraction"`` entry.
    """
    return self._params.get("uncertainty_gm_fraction")
@property
def ambiguity_gm_fraction(self):
    """Per-microstructure ambiguity for GM fraction.

    Expressed as a fraction of the prior range; ``None`` when the fit
    parameters hold no ``"ambiguity_gm_fraction"`` entry.
    """
    return self._params.get("ambiguity_gm_fraction")
@property
def uncertainty_csf_fraction(self):
    """Per-microstructure uncertainty for CSF fraction.

    Expressed as a fraction of the prior range; ``None`` when the fit
    parameters hold no ``"uncertainty_csf_fraction"`` entry.
    """
    return self._params.get("uncertainty_csf_fraction")
@property
def ambiguity_csf_fraction(self):
    """Per-microstructure ambiguity for CSF fraction.

    Expressed as a fraction of the prior range; ``None`` when the fit
    parameters hold no ``"ambiguity_csf_fraction"`` entry.
    """
    return self._params.get("ambiguity_csf_fraction")
@property
def uncertainty_num_fibers(self):
    """Per-microstructure uncertainty for the number of fibers.

    Expressed as a fraction of the prior range; ``None`` when the fit
    parameters hold no ``"uncertainty_num_fibers"`` entry.
    """
    return self._params.get("uncertainty_num_fibers")
@property
def ambiguity_num_fibers(self):
    """Per-microstructure ambiguity for the number of fibers.

    Expressed as a fraction of the prior range; ``None`` when the fit
    parameters hold no ``"ambiguity_num_fibers"`` entry.
    """
    return self._params.get("ambiguity_num_fibers")
@property
def uncertainty_dispersion(self):
    """Per-microstructure uncertainty for dispersion.

    Expressed as a fraction of the prior range; ``None`` when the fit
    parameters hold no ``"uncertainty_dispersion"`` entry.
    """
    return self._params.get("uncertainty_dispersion")
@property
def ambiguity_dispersion(self):
    """Per-microstructure ambiguity for dispersion.

    Expressed as a fraction of the prior range; ``None`` when the fit
    parameters hold no ``"ambiguity_dispersion"`` entry.
    """
    return self._params.get("ambiguity_dispersion")
@property
def uncertainty_nd(self):
    """Per-microstructure uncertainty for neurite density (ND).

    Expressed as a fraction of the prior range; ``None`` when the fit
    parameters hold no ``"uncertainty_nd"`` entry.
    """
    return self._params.get("uncertainty_nd")
@property
def ambiguity_nd(self):
    """Per-microstructure ambiguity for neurite density (ND).

    Expressed as a fraction of the prior range; ``None`` when the fit
    parameters hold no ``"ambiguity_nd"`` entry.
    """
    return self._params.get("ambiguity_nd")
@property
def uncertainty_ufa_voxel(self):
    """Per-microstructure uncertainty for voxel-averaged microFA.

    Expressed as a fraction of the prior range; ``None`` when the fit
    parameters hold no ``"uncertainty_ufa_voxel"`` entry.
    """
    return self._params.get("uncertainty_ufa_voxel")
@property
def ambiguity_ufa_voxel(self):
    """Per-microstructure ambiguity for voxel-averaged microFA.

    Expressed as a fraction of the prior range; ``None`` when the fit
    parameters hold no ``"ambiguity_ufa_voxel"`` entry.
    """
    return self._params.get("ambiguity_ufa_voxel")
@property
def uncertainty_ak(self):
    """Per-microstructure uncertainty for AK.

    Expressed as a fraction of the prior range; ``None`` when the fit
    parameters hold no ``"uncertainty_ak"`` entry.
    """
    return self._params.get("uncertainty_ak")
@property
def ambiguity_ak(self):
    """Per-microstructure ambiguity for AK.

    Expressed as a fraction of the prior range; ``None`` when the fit
    parameters hold no ``"ambiguity_ak"`` entry.
    """
    return self._params.get("ambiguity_ak")
@property
def uncertainty_rk(self):
    """Per-microstructure uncertainty for RK.

    Expressed as a fraction of the prior range; ``None`` when the fit
    parameters hold no ``"uncertainty_rk"`` entry.
    """
    return self._params.get("uncertainty_rk")
@property
def ambiguity_rk(self):
    """Per-microstructure ambiguity for RK.

    Expressed as a fraction of the prior range; ``None`` when the fit
    parameters hold no ``"ambiguity_rk"`` entry.
    """
    return self._params.get("ambiguity_rk")
@property
def uncertainty_mk(self):
    """Per-microstructure uncertainty for MK.

    Expressed as a fraction of the prior range; ``None`` when the fit
    parameters hold no ``"uncertainty_mk"`` entry.
    """
    return self._params.get("uncertainty_mk")
@property
def ambiguity_mk(self):
    """Per-microstructure ambiguity for MK.

    Expressed as a fraction of the prior range; ``None`` when the fit
    parameters hold no ``"ambiguity_mk"`` entry.
    """
    return self._params.get("ambiguity_mk")
@property
def uncertainty_kfa(self):
    """Per-microstructure uncertainty for KFA.

    Expressed as a fraction of the prior range; ``None`` when the fit
    parameters hold no ``"uncertainty_kfa"`` entry.
    """
    return self._params.get("uncertainty_kfa")
@property
def ambiguity_kfa(self):
    """Per-microstructure ambiguity for KFA.

    Expressed as a fraction of the prior range; ``None`` when the fit
    parameters hold no ``"ambiguity_kfa"`` entry.
    """
    return self._params.get("ambiguity_kfa")
# Resolve the forward reference: FORCEModel is defined before FORCEFit,
# so FORCEModel's ``_fit_class`` attribute is bound here, after FORCEFit
# has been defined.
FORCEModel._fit_class = FORCEFit
[docs]
def compute_entropy(weights):
    """Compute the Shannon entropy of each row of posterior weights.

    Parameters
    ----------
    weights : ndarray (N, K)
        Posterior weights for K neighbors.

    Returns
    -------
    entropy : ndarray (N,)
        Shannon entropy for each sample, as float32.

    Notes
    -----
    ``EPSILON`` is added inside the logarithm so that a zero weight
    contributes ``0 * log(EPSILON) = 0`` instead of ``0 * log(0) = nan``.
    """
    log_weights = np.log(weights + EPSILON)
    row_entropy = -(weights * log_weights).sum(axis=1)
    return row_entropy.astype(np.float32)
[docs]
def posterior_mean_signal(signals, weights, indices):
    """Average the library signals of each query's neighbors.

    Parameters
    ----------
    signals : ndarray (N_lib, M)
        Library signals.
    weights : ndarray (N_query, K)
        Posterior weights.
    indices : ndarray (N_query, K)
        Neighbor indices into ``signals``.

    Returns
    -------
    mean_signal : ndarray (N_query, M)
        Weighted sum of neighbor signals, accumulated in float32.
    """
    n_query, n_grad = indices.shape[0], signals.shape[1]
    n_neighbors = indices.shape[1]
    mean_signal = np.zeros((n_query, n_grad), dtype=np.float32)
    # Gathering one neighbor column per iteration keeps each temporary at
    # (N_query, M) rather than (N_query, K, M).
    for col in range(n_neighbors):
        neighbor_signal = signals[indices[:, col]]
        mean_signal += weights[:, col : col + 1] * neighbor_signal
    return mean_signal
[docs]
def posterior_odf(odfs, weights, indices, n_dirs):
    """Blend neighbor ODFs into a posterior mean ODF.

    Parameters
    ----------
    odfs : ndarray (N_lib, D)
        Library ODFs.
    weights : ndarray (N_query, K)
        Posterior weights.
    indices : ndarray (N_query, K)
        Neighbor indices into ``odfs``.
    n_dirs : int
        Number of sphere directions.

    Returns
    -------
    odf : ndarray (N_query, D)
        Posterior mean ODFs in float32. Each row is divided by its own
        maximum (plus ``EPSILON``).
    """
    n_query, n_neighbors = indices.shape[0], indices.shape[1]
    blended = np.zeros((n_query, n_dirs), dtype=np.float32)
    for col in range(n_neighbors):
        # float32 copy of this neighbor's ODFs, scaled by their own peak
        # so every neighbor is on the same scale before weighting.
        neighbor_odf = odfs[indices[:, col]].astype(np.float32)
        neighbor_odf /= neighbor_odf.max(axis=1, keepdims=True) + EPSILON
        blended += weights[:, col : col + 1] * neighbor_odf
    blended /= blended.max(axis=1, keepdims=True) + EPSILON
    return blended
[docs]
def force_peaks(fitted_object, *, mask=None, sh_order=8):
    """Create a PeaksAndMetrics object from a FORCEFit or MultiVoxelFit.

    Parameters
    ----------
    fitted_object : FORCEFit or MultiVoxelFit
        The result of model.fit(). Its ``label``, ``fracs`` and ``odf``
        attributes are read. A 1-D ``label`` selects the single-voxel
        path; a ``label`` with more dimensions selects the multi-voxel
        path.
    mask : array_like, optional
        Optional brain mask over the spatial dimensions. It is only used
        on the multi-voxel path. It is converted to a boolean array, so
        0/1 integer masks select voxels rather than being treated as
        integer indices.
    sh_order : int, optional
        Spherical harmonics order for the coefficients.

    Returns
    -------
    peaks : PeaksAndMetrics
        Has ``peak_dirs``, ``peak_values``, ``peak_indices``, ``shm_coeff``
        and ``sphere`` set. ``shm_coeff`` is ``None`` when no usable ODF is
        available. On the multi-voxel path, the peak arrays are
        float32/int32 whether or not a mask is given. Voxels outside
        ``mask`` have zero directions and values and index -1.
    """
    from dipy.sims.force import default_sphere

    labels = fitted_object.label
    fracs = fitted_object.fracs
    odf = fitted_object.odf

    if labels.ndim > 1:
        res_dirs, res_inds, res_vals = _force_multi_voxel_peaks(
            labels, fracs, default_sphere, mask
        )
        res_sh = _force_multi_voxel_sh(odf, default_sphere, sh_order)
    else:
        # Single voxel: postprocess_peaks works on a batch, so add a
        # leading axis and strip it from each result.
        p_out, p_ind, p_val = postprocess_peaks(
            labels[None, :], default_sphere, fracs[None, :]
        )
        res_dirs, res_inds, res_vals = p_out[0], p_ind[0], p_val[0]
        res_sh = (
            sf_to_sh(odf, default_sphere, sh_order=sh_order)
            if odf is not None
            else None
        )

    peaks = PeaksAndMetrics()
    peaks.peak_dirs = res_dirs
    peaks.peak_values = res_vals
    peaks.peak_indices = res_inds
    peaks.shm_coeff = res_sh
    peaks.sphere = default_sphere
    return peaks


def _force_multi_voxel_peaks(labels, fracs, sphere, mask):
    """Extract peaks voxel-wise and lay them back out on the spatial grid."""
    spatial_shape = labels.shape[:-1]

    if mask is None:
        p_out, p_ind, p_val = postprocess_peaks(
            labels.reshape(-1, labels.shape[-1]),
            sphere,
            fracs.reshape(-1, fracs.shape[-1]),
        )
        # Same dtypes as the masked branch below.
        res_dirs = np.asarray(p_out, dtype=np.float32)
        res_inds = np.asarray(p_ind, dtype=np.int32)
        res_vals = np.asarray(p_val, dtype=np.float32)
        return (
            res_dirs.reshape(spatial_shape + res_dirs.shape[1:]),
            res_inds.reshape(spatial_shape + res_inds.shape[1:]),
            res_vals.reshape(spatial_shape + res_vals.shape[1:]),
        )

    # Boolean cast: an integer 0/1 mask would otherwise index positions
    # along the first axis instead of selecting voxels.
    mask = np.asarray(mask, dtype=bool)
    p_out, p_ind, p_val = postprocess_peaks(labels[mask], sphere, fracs[mask])
    # The per-voxel peak layout comes from postprocess_peaks itself rather
    # than a hard-coded peak count.
    res_dirs = np.zeros(spatial_shape + np.shape(p_out)[1:], dtype=np.float32)
    res_inds = np.full(spatial_shape + np.shape(p_ind)[1:], -1, dtype=np.int32)
    res_vals = np.zeros(spatial_shape + np.shape(p_val)[1:], dtype=np.float32)
    res_dirs[mask] = p_out
    res_inds[mask] = p_ind
    res_vals[mask] = p_val
    return res_dirs, res_inds, res_vals


def _force_multi_voxel_sh(odf, sphere, sh_order):
    """Fit SH coefficients to a floating-point multi-voxel ODF, else None."""
    # getattr also covers ``odf is None`` and array-likes without a dtype.
    odf_dtype = getattr(odf, "dtype", None)
    if odf_dtype is None or not np.issubdtype(odf_dtype, np.floating):
        return None
    v_max = np.max(odf, axis=-1, keepdims=True)
    v_min = np.min(odf, axis=-1, keepdims=True)
    # Voxels whose ODF exceeds 1 are min-max rescaled to roughly [0, 1].
    # Voxels whose maximum is <= 1 pass through unchanged.
    needs_rescale = v_max > 1.0
    rescaled = (odf - v_min) / ((v_max - v_min) + 1e-12)
    return sf_to_sh(
        np.where(needs_rescale, rescaled, odf), sphere, sh_order=sh_order
    )