diff --git a/.github/workflows/build-test-publish.yml b/.github/workflows/build-test-publish.yml index 5c7adfdde0..c5759cbae5 100644 --- a/.github/workflows/build-test-publish.yml +++ b/.github/workflows/build-test-publish.yml @@ -85,6 +85,15 @@ jobs: dependencies: "pre" steps: + - name: Free disk space + uses: jlumbroso/free-disk-space@v1.3.1 + with: + tool-cache: false + android: true + dotnet: true + haskell: true + large-packages: false + swap-storage: false - uses: actions/checkout@v5 - uses: actions/cache@v4 with: @@ -161,7 +170,7 @@ jobs: - uses: actions/cache@v4 with: path: ${{ env.TEST_DATA_HOME }} - key: data-cache-v2 + key: data-cache-v3 restore-keys: | data-cache- - name: Install test data @@ -206,6 +215,16 @@ jobs: datalad update -r --merge -d hcph-pilot_fieldmaps/ datalad get -r -J 2 -d hcph-pilot_fieldmaps/ hcph-pilot_fieldmaps/* + # ds006926 — MEDIC multi-echo mag+phase BOLD (sub-a01 only) + datalad install -r https://github.com/OpenNeuroDatasets/ds006926.git + datalad update -r --merge -d ds006926/ + datalad get -r -J 2 -d ds006926/ ds006926/sub-a01/func/sub-a01_task-VisMot_acq-tr1800_* + + # ds007637 — MEDIC multi-echo mag+phase BOLD (sub-04/ses-2 fracback only) + datalad install -r https://github.com/OpenNeuroDatasets/ds007637.git + datalad update -r --merge -d ds007637/ + datalad get -r -J 2 -d ds007637/ ds007637/sub-04/ses-2/func/sub-04_ses-2_task-fracback_acq-MBME_echo-*_part-mag_bold.nii.gz ds007637/sub-04/ses-2/func/sub-04_ses-2_task-fracback_acq-MBME_echo-*_part-phase_bold.nii.gz + - name: Set FreeSurfer variables run: | echo "FREESURFER_HOME=$HOME/.cache/freesurfer" >> $GITHUB_ENV diff --git a/.zenodo.json b/.zenodo.json index 9894144510..8f155af06d 100644 --- a/.zenodo.json +++ b/.zenodo.json @@ -135,6 +135,12 @@ "affiliation": "Department of Psychology, Stanford University, CA, USA", "name": "Russell A. Poldrack", "type": "Researcher" + }, + { + "orcid": "0000-0002-8787-0943", + "affiliation": "Department of Biomedical Engineering, Washington University in St. Louis, MO, USA", + "name": "Andrew Van", + "type": "Researcher" } ], "keywords": [ diff --git a/pyproject.toml b/pyproject.toml index 303b3cfc46..6daee63782 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,6 +35,8 @@ dependencies = [ "scipy >= 1.10", "templateflow >= 23.1", "toml >= 0.10", + # The marker keeps Python 3.10 installs working — warpkit requires >= 3.11. + "warpkit >= 1.4.0; python_version >= '3.11'", ] dynamic = ["version"] diff --git a/sdcflows/cli/main.py b/sdcflows/cli/main.py index 9336903f31..dfeb0fb04b 100644 --- a/sdcflows/cli/main.py +++ b/sdcflows/cli/main.py @@ -70,6 +70,7 @@ def main(argv=None): layout=config.execution.layout, subject=subject, fmapless=config.workflow.fmapless, + no_medic=config.workflow.no_medic, logger=config.loggers.cli, ) diff --git a/sdcflows/cli/parser.py b/sdcflows/cli/parser.py index ed89836d1f..c8551d8c32 100644 --- a/sdcflows/cli/parser.py +++ b/sdcflows/cli/parser.py @@ -247,6 +247,14 @@ def _bids_filter(value): default=True, help='Allow fieldmap-less estimation', ) + g_outputs.add_argument( + '--no-medic', + action='store_true', + dest='no_medic', + default=False, + help='Disable MEDIC discovery (by default MEDIC takes priority for ' + 'complex multi-echo BOLD)', + ) g_outputs.add_argument( '--use-plugin', action='store', diff --git a/sdcflows/config.py b/sdcflows/config.py index 98a660c7dd..a0b6c3bb09 100644 --- a/sdcflows/config.py +++ b/sdcflows/config.py @@ -498,6 +498,8 @@ class workflow(_Config): """Level of analysis.""" fmapless = False """Allow fieldmap-less estimation""" + no_medic = False + """Disable MEDIC discovery (otherwise MEDIC takes priority for complex multi-echo BOLD)""" species = 'human' """Subject species to choose most appropriate template""" template_id = 'MNI152NLin2009cAsym' diff --git a/sdcflows/conftest.py b/sdcflows/conftest.py index cdacab6292..f9d03085f9 100644 --- a/sdcflows/conftest.py +++ b/sdcflows/conftest.py @@ -41,10 +41,17 @@ test_workdir = os.getenv('TEST_WORK_DIR') _sloppy_mode = os.getenv('TEST_PRODUCTION', 'off').lower() not in ('on', '1', 'true', 'yes', 'y') +# MEDIC fixtures live in full OpenNeuro trees (tens of thousands of JSON +# sidecars) but only a few files are actually fetched via ``datalad get``. +# Indexing those trees with ``BIDSLayout(derivatives=True)`` at collection +# time stalled CI past the 20-minute tox watchdog. The MEDIC tests reach +# their files via the ``datadir`` fixture directly, not via ``layouts``. +_SKIP_LAYOUTS = {'ds006926', 'ds007637'} + layouts = { p.name: BIDSLayout(str(p), validate=False, derivatives=True) for p in Path(test_data_env).glob('*') - if p.is_dir() + if p.is_dir() and p.name not in _SKIP_LAYOUTS } data_dir = Path(__file__).parent / 'tests' / 'data' @@ -128,3 +135,52 @@ def dsA_dir(): @pytest.fixture def sloppy_mode(): return _sloppy_mode + + +# MEDIC end-to-end fixtures, shared by the fit (``test_medic``) and apply +# (``test_dynamic``) test modules. A handful of timepoints is enough to +# exercise the full per-volume path; the source datasets ship 200+ volumes × +# 5 echoes × mag+phase, which OOM-kills CI runners when xdist schedules these +# in parallel. +_MEDIC_DATASETS = [ + pytest.param( + ( + 'ds007637', + 'sub-04/ses-2/func/sub-04_ses-2_task-fracback_acq-MBME_echo-*_part-mag_bold.nii.gz', + ), + id='ds007637', + ), + pytest.param( + ('ds006926', 'sub-a01/func/sub-a01_task-VisMot_acq-tr1800_echo-*_part-mag_bold.nii.gz'), + id='ds006926', + ), +] + + +@pytest.fixture +def medic_test_volumes(): + return 3 + + +@pytest.fixture(params=_MEDIC_DATASETS) +def medic_fixture(request): + """Yield ``(dataset, mag_glob_under_dataset)`` for each MEDIC fixture.""" + return request.param + + +@pytest.fixture +def truncate_to_volumes(): + """Return a helper that slices 4D NIfTIs down to ``volumes`` timepoints.""" + + def _truncate(in_files, volumes, dest): + out = [] + for f in in_files: + img = nibabel.load(str(f)) + if img.shape[-1] > volumes: + img = img.slicer[..., :volumes] + new = dest / f.name + img.to_filename(new) + out.append(new) + return out + + return _truncate diff --git a/sdcflows/fieldmaps.py b/sdcflows/fieldmaps.py index 149d584792..722deb6044 100644 --- a/sdcflows/fieldmaps.py +++ b/sdcflows/fieldmaps.py @@ -55,6 +55,7 @@ class EstimatorType(Enum): PHASEDIFF = auto() MAPPED = auto() ANAT = auto() + MEDIC = auto() MODALITIES = { @@ -75,6 +76,12 @@ class EstimatorType(Enum): 'T2w': EstimatorType.ANAT, } +# Estimator types that emit a per-volume 4D fieldmap on the EPI grid (and +# therefore do not produce B-spline coefficients). Add new dynamic methods +# here so consumers — ``init_fmap_preproc_wf`` in particular — pick them up +# without per-method branching. +_DYNAMIC_METHODS = frozenset({EstimatorType.MEDIC}) + def _type_setter(obj, attribute, value): """Make sure the type of estimation is not changed.""" @@ -88,6 +95,7 @@ def _type_setter(obj, attribute, value): EstimatorType.PHASEDIFF, EstimatorType.MAPPED, EstimatorType.ANAT, + EstimatorType.MEDIC, ): raise ValueError(f'Invalid estimation method type {value}.') @@ -338,6 +346,36 @@ def __attrs_post_init__(self): suffix_list = [f.suffix for f in self.sources] suffix_set = set(suffix_list) + # Fieldmap option 0: MEDIC — multi-echo phase + magnitude + # ``bold`` / ``epi`` sources tagged with the BIDS ``part-{phase,mag}`` + # entity. PEPOLAR uses ``dir-`` instead, so the part entity is the + # cleanest way to disambiguate. + parts = {f.entities.get('part') for f in self.sources} + medic_parts = parts & {'phase', 'mag'} + if suffix_set <= {'bold', 'epi', 'sbref'} and medic_parts: + # Any sources is ``part``-tagged: this is a MEDIC-shaped input. + # Reject incomplete sets explicitly rather than letting them slip + # through to the PEPOLAR branch and produce a confusing failure. + if parts != {'phase', 'mag'}: + raise ValueError( + 'MEDIC requires every source to be tagged ``part-mag`` or ' + '``part-phase``, with both present; got ' + f'parts={sorted(str(p) for p in parts)!r}.' + ) + phase_files = [f for f in self.sources if f.entities.get('part') == 'phase'] + mag_files = [f for f in self.sources if f.entities.get('part') == 'mag'] + if len(phase_files) < 2: + raise ValueError( + f'MEDIC requires at least two echoes of phase data; got {len(phase_files)}.' + ) + if len(phase_files) != len(mag_files): + raise ValueError( + f'MEDIC requires matched magnitude/phase pairs per echo; ' + f'got {len(phase_files)} phase and {len(mag_files)} ' + 'magnitude file(s).' + ) + self.method = EstimatorType.MEDIC + # Fieldmap option 1: actual field-mapping sequences fmap_types = suffix_set.intersection(('fieldmap', 'phasediff', 'phase1', 'phase2')) if len(fmap_types) > 1 and fmap_types - {'phase1', 'phase2'}: @@ -399,7 +437,7 @@ def __attrs_post_init__(self): > 1 ) - if _pepolar_estimation and not anat_types: + if self.method == EstimatorType.UNKNOWN and _pepolar_estimation and not anat_types: self.method = MODALITIES[pepolar_types.pop()] _pe = {f.metadata['PhaseEncodingDirection'] for f in self.sources} if len(_pe) == 1: @@ -455,6 +493,11 @@ def __attrs_post_init__(self): # special characters are not allowed. self.sanitized_id = re.sub(r'[^a-zA-Z0-9]', '_', self.bids_id) + @property + def is_dynamic(self) -> bool: + """The estimator emits a per-volume 4D fieldmap and no B-spline coefficients.""" + return self.method in _DYNAMIC_METHODS + def paths(self): """Return a tuple of paths that are sorted.""" return tuple(sorted(str(f.path) for f in self.sources)) @@ -502,6 +545,28 @@ def get_workflow(self, set_inputs=True, **kwargs): from .workflows.fit.syn import init_syn_sdc_wf self._wf = init_syn_sdc_wf(**kwargs) + elif self.method == EstimatorType.MEDIC: + from .workflows.fit.medic import init_medic_wf + + for f in self.sources: + if not f.path.is_file(): + raise FileNotFoundError( + f'File path <{f.path}> does not exist, ' + 'is a broken link, or it is not a file' + ) + + self._wf = init_medic_wf(**kwargs) + + if set_inputs: + phase_files = [f for f in self.sources if f.entities.get('part') == 'phase'] + mag_files = [f for f in self.sources if f.entities.get('part') == 'mag'] + # Order both lists by EchoTime so warpkit gets aligned echo + # series. BIDS does not guarantee echo entity == numeric order. + phase_files = sorted(phase_files, key=lambda f: f.metadata['EchoTime']) + mag_files = sorted(mag_files, key=lambda f: f.metadata['EchoTime']) + self._wf.inputs.inputnode.phase = [str(f.path.absolute()) for f in phase_files] + self._wf.inputs.inputnode.magnitude = [str(f.path.absolute()) for f in mag_files] + self._wf.inputs.inputnode.metadata = [f.metadata for f in phase_files] return self._wf diff --git a/sdcflows/interfaces/tests/test_warpkit.py b/sdcflows/interfaces/tests/test_warpkit.py new file mode 100644 index 0000000000..c0d859ce24 --- /dev/null +++ b/sdcflows/interfaces/tests/test_warpkit.py @@ -0,0 +1,75 @@ +# emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*- +# vi: set ft=python sts=4 ts=4 sw=4 et: +# +# Copyright The NiPreps Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# We support and encourage derived works from this project, please read +# about our expectations at +# +# https://www.nipreps.org/community/licensing/ +# +"""Tests for the warpkit nipype interface wrappers. + +These tests check spec shape and small helpers; the actual ``_run_interface`` +methods require :mod:`warpkit`, which is an optional dependency. +""" + +import pytest + +from sdcflows.interfaces import warpkit as wk + + +def test_warpkit_base_interface_pkg(): + """All warpkit interfaces share the ``warpkit`` library tag. + + This lets nipype's ``LibraryBaseInterface`` emit a single, consistent + "warpkit not installed" message rather than per-class noise. + """ + assert wk.WarpkitBaseInterface._pkg == 'warpkit' + for cls in (wk.UnwrapPhase, wk.ComputeFieldmap): + assert issubclass(cls, wk.WarpkitBaseInterface) + + +@pytest.mark.parametrize( + 'cls,expected_inputs,expected_outputs', + [ + ( + wk.UnwrapPhase, + {'phase', 'magnitude', 'echo_times'}, + {'unwrapped', 'masks'}, + ), + ( + wk.ComputeFieldmap, + {'unwrapped', 'magnitude', 'masks', 'border_filt', 'svd_filt'}, + {'fieldmap_native', 'displacement_map', 'fieldmap'}, + ), + ], +) +def test_interface_spec_traits(cls, expected_inputs, expected_outputs): + """Each interface declares the expected input/output traits.""" + iface = cls() + assert expected_inputs <= set(iface.inputs.copyable_trait_names()) + assert expected_outputs <= set(iface.output_spec().copyable_trait_names()) + + +def test_compute_fieldmap_border_filt_default(): + """``border_filt`` defaults to ``(1, 5)``. + + Regression test for an upstream ``traits.Tuple`` quirk where the outer + ``default`` kwarg silently lost to inner ``Int()`` zeros, collapsing the + SVD border filter and clipping the dynamic fieldmap footprint. + """ + iface = wk.ComputeFieldmap() + assert tuple(iface.inputs.border_filt) == (1, 5) diff --git a/sdcflows/interfaces/warpkit.py b/sdcflows/interfaces/warpkit.py new file mode 100644 index 0000000000..1a7d3a3e14 --- /dev/null +++ b/sdcflows/interfaces/warpkit.py @@ -0,0 +1,187 @@ +# emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*- +# vi: set ft=python sts=4 ts=4 sw=4 et: +# +# Copyright The NiPreps Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# We support and encourage derived works from this project, please read +# about our expectations at +# +# https://www.nipreps.org/community/licensing/ +# +"""Nipype interfaces wrapping :mod:`warpkit.api`. + +`warpkit `__ implements MEDIC +(Multi-Echo DIstortion Correction). This module exposes the two MEDIC +stages SDCFlows actually drives — phase unwrapping and fieldmap +computation — as :class:`~nipype.interfaces.base.SimpleInterface` +subclasses calling :mod:`warpkit.api` in-process. ``warpkit`` is an +optional dependency; :class:`~nipype.interfaces.base.LibraryBaseInterface` +emits a clean "package not installed" error if it is missing at runtime. +""" + +import os + +from nipype.interfaces.base import ( + BaseInterfaceInputSpec, + File, + InputMultiObject, + LibraryBaseInterface, + OutputMultiObject, + SimpleInterface, + TraitedSpec, + isdefined, + traits, +) + +PE_DIRECTIONS = ('i', 'j', 'k', 'i-', 'j-', 'k-', 'x', 'y', 'z', 'x-', 'y-', 'z-') + + +class WarpkitBaseInterface(LibraryBaseInterface): + """Base for all warpkit-backed interfaces.""" + + _pkg = 'warpkit' + + +# --------------------------------------------------------------------------- +# UnwrapPhase — ROMEO multi-echo phase unwrapping +# --------------------------------------------------------------------------- + + +class _UnwrapPhaseInputSpec(BaseInterfaceInputSpec): + phase = InputMultiObject(File(exists=True), mandatory=True) + magnitude = InputMultiObject(File(exists=True), mandatory=True) + echo_times = traits.List(traits.Float, xor=['metadata']) + metadata = InputMultiObject(File(exists=True), xor=['echo_times']) + out_prefix = traits.Str('unwrap', usedefault=True) + n_cpus = traits.Int(4, usedefault=True) + wrap_limit = traits.Bool(False, usedefault=True) + debug = traits.Bool(False, usedefault=True) + + +class _UnwrapPhaseOutputSpec(TraitedSpec): + unwrapped = OutputMultiObject(File(exists=True), desc='unwrapped phase per echo') + masks = File(exists=True, desc='per-frame masks NIfTI') + + +class UnwrapPhase(WarpkitBaseInterface, SimpleInterface): + """ROMEO multi-echo phase unwrapping (:func:`warpkit.api.unwrap_phase`).""" + + input_spec = _UnwrapPhaseInputSpec + output_spec = _UnwrapPhaseOutputSpec + + def _run_interface(self, runtime): + from warpkit.api import unwrap_phase + + out_prefix = os.path.join(runtime.cwd, self.inputs.out_prefix) + try: + result = unwrap_phase( + phase=list(self.inputs.phase), + magnitude=list(self.inputs.magnitude), + out_prefix=out_prefix, + tes=(list(self.inputs.echo_times) if isdefined(self.inputs.echo_times) else None), + metadata=(list(self.inputs.metadata) if isdefined(self.inputs.metadata) else None), + n_cpus=self.inputs.n_cpus, + wrap_limit=self.inputs.wrap_limit, + debug=self.inputs.debug, + ) + except ValueError as e: + raise RuntimeError(str(e)) from e + + self._results['unwrapped'] = [str(p) for p in result.unwrapped] + self._results['masks'] = str(result.masks) + return runtime + + +# --------------------------------------------------------------------------- +# ComputeFieldmap — post-unwrap stage of MEDIC +# --------------------------------------------------------------------------- + + +class _ComputeFieldmapInputSpec(BaseInterfaceInputSpec): + unwrapped = InputMultiObject( + File(exists=True), + mandatory=True, + desc='unwrapped phase per echo (output of UnwrapPhase)', + ) + magnitude = InputMultiObject(File(exists=True), mandatory=True) + masks = File(exists=True, mandatory=True, desc='per-frame masks (output of UnwrapPhase)') + echo_times = traits.List(traits.Float, xor=['metadata']) + total_readout_time = traits.Float(xor=['metadata']) + phase_encoding_direction = traits.Enum(*PE_DIRECTIONS, xor=['metadata']) + metadata = InputMultiObject( + File(exists=True), + xor=['echo_times', 'total_readout_time', 'phase_encoding_direction'], + ) + out_prefix = traits.Str('fieldmap', usedefault=True) + # NOTE: `traits.Tuple(Int(), Int(), default=(1, 5))` is silently ignored — + # the inner Int()s default to 0, and the outer `default` kwarg loses. + # That collapses the border-filter to 0 SVD components, which zeros the + # mask==1 ring in warpkit.unwrap.svd_filtering and makes the dynamic + # fieldmap appear hard-brain-masked. Pass defaults to the inner Ints. + border_filt = traits.Tuple( + traits.Int(1), + traits.Int(5), + usedefault=True, + desc='SVD components for the two-pass border filter', + ) + svd_filt = traits.Int(10, usedefault=True) + n_cpus = traits.Int(4, usedefault=True) + + +class _ComputeFieldmapOutputSpec(TraitedSpec): + fieldmap_native = File(exists=True) + displacement_map = File(exists=True) + fieldmap = File(exists=True) + + +class ComputeFieldmap(WarpkitBaseInterface, SimpleInterface): + """Post-unwrap MEDIC stage (:func:`warpkit.api.compute_fieldmap`).""" + + input_spec = _ComputeFieldmapInputSpec + output_spec = _ComputeFieldmapOutputSpec + + def _run_interface(self, runtime): + from warpkit.api import compute_fieldmap + + out_prefix = os.path.join(runtime.cwd, self.inputs.out_prefix) + try: + result = compute_fieldmap( + unwrapped=list(self.inputs.unwrapped), + magnitude=list(self.inputs.magnitude), + masks=self.inputs.masks, + out_prefix=out_prefix, + tes=(list(self.inputs.echo_times) if isdefined(self.inputs.echo_times) else None), + total_readout_time=( + self.inputs.total_readout_time + if isdefined(self.inputs.total_readout_time) + else None + ), + phase_encoding_direction=( + self.inputs.phase_encoding_direction + if isdefined(self.inputs.phase_encoding_direction) + else None + ), + metadata=(list(self.inputs.metadata) if isdefined(self.inputs.metadata) else None), + border_filt=tuple(self.inputs.border_filt), + svd_filt=self.inputs.svd_filt, + n_cpus=self.inputs.n_cpus, + ) + except ValueError as e: + raise RuntimeError(str(e)) from e + + self._results['fieldmap_native'] = str(result.fieldmap_native) + self._results['displacement_map'] = str(result.displacement_map) + self._results['fieldmap'] = str(result.fieldmap) + return runtime diff --git a/sdcflows/tests/test_fieldmaps.py b/sdcflows/tests/test_fieldmaps.py index 383376b988..aefba6fbb6 100644 --- a/sdcflows/tests/test_fieldmaps.py +++ b/sdcflows/tests/test_fieldmaps.py @@ -333,6 +333,69 @@ def test_FieldmapEstimation_missing_files(tmpdir, dsA_dir): ) +def _make_medic_files(dest, source_path, specs): + """Copy a NIfTI into part/echo-tagged filenames under ``dest``. + + Returns the list of :class:`~sdcflows.fieldmaps.FieldmapFile`s built from + those paths with the minimum metadata required for MEDIC source files. + """ + base_meta = {'PhaseEncodingDirection': 'j', 'TotalReadoutTime': 0.05} + files = [] + for echo, part in specs: + name = f'sub-01_task-rest_echo-{echo}_part-{part}_bold.nii.gz' + dst = dest / name + shutil.copy(str(source_path), str(dst)) + files.append( + fm.FieldmapFile( + str(dst), + metadata={**base_meta, 'EchoTime': 0.01 * echo}, + ) + ) + return files + + +def test_FieldmapEstimation_MEDIC_requires_both_parts(tmp_path, dsA_dir): + """MEDIC sources tagged only ``part-phase`` must reject in the part guard.""" + src = dsA_dir / 'sub-01' / 'func' / 'sub-01_task-rest_bold.nii.gz' + files = _make_medic_files(tmp_path, src, [(1, 'phase'), (2, 'phase')]) + with pytest.raises(ValueError, match='MEDIC requires every source'): + fm.FieldmapEstimation(files) + + +def test_FieldmapEstimation_MEDIC_single_echo(tmp_path, dsA_dir): + """One mag+phase pair (single echo) must reject in the echo-count guard.""" + src = dsA_dir / 'sub-01' / 'func' / 'sub-01_task-rest_bold.nii.gz' + files = _make_medic_files(tmp_path, src, [(1, 'mag'), (1, 'phase')]) + with pytest.raises(ValueError, match='at least two echoes of phase'): + fm.FieldmapEstimation(files) + + +def test_FieldmapEstimation_MEDIC_mismatched_pairs(tmp_path, dsA_dir): + """Unequal magnitude / phase echo counts must reject in the pairing guard.""" + src = dsA_dir / 'sub-01' / 'func' / 'sub-01_task-rest_bold.nii.gz' + files = _make_medic_files(tmp_path, src, [(1, 'phase'), (2, 'phase'), (1, 'mag')]) + with pytest.raises(ValueError, match='matched magnitude/phase pairs'): + fm.FieldmapEstimation(files) + + +def test_FieldmapEstimation_is_dynamic(tmp_path, dsA_dir): + """``is_dynamic`` flags MEDIC (4D fmap on EPI grid) and not the static estimators.""" + src = dsA_dir / 'sub-01' / 'func' / 'sub-01_task-rest_bold.nii.gz' + medic_files = _make_medic_files( + tmp_path, src, [(1, 'mag'), (1, 'phase'), (2, 'mag'), (2, 'phase')] + ) + assert fm.FieldmapEstimation(medic_files).is_dynamic is True + + sub_dir = dsA_dir / 'sub-01' + phasediff = fm.FieldmapEstimation( + [ + sub_dir / 'fmap/sub-01_phase1.nii.gz', + sub_dir / 'fmap/sub-01_phase2.nii.gz', + ] + ) + assert phasediff.is_dynamic is False + + def test_FieldmapFile_filename(tmp_path, dsA_dir): datadir = tmp_path / 'phasediff' datadir.mkdir(exist_ok=True) diff --git a/sdcflows/transform.py b/sdcflows/transform.py index d2611e8ba9..21255b6f8a 100644 --- a/sdcflows/transform.py +++ b/sdcflows/transform.py @@ -106,17 +106,40 @@ def _sdc_unwarp( prefilter=prefilter, ) - # The Jacobian determinant image is the amount of stretching in the PE direction. - # Using central differences accounts for the shift in neighboring voxels. - # The full Jacobian at each voxel would be a 3x3 matrix, but because there is - # only warping in one direction, we end up with a diagonal matrix with two 1s. - # The following is the other entry at each voxel, and hence the determinant. if jacobian: - resampled *= 1 + np.gradient(vsm, axis=pe_info[0]) + resampled *= fieldmap_jacobian(vsm, pe_info[0]) return resampled +def fieldmap_jacobian( + vsm: np.ndarray, + pe_axis: int, +) -> np.ndarray: + r""" + Voxel-wise Jacobian determinant of a one-axis (PE) EPI distortion. + + EPI distortion only acts along the phase-encoding axis, so the full + 3×3 Jacobian collapses to a diagonal with two 1s and one nontrivial + entry: :math:`|J| \approx 1 + \partial(\mathrm{VSM})/\partial(\mathrm{PE})` + (central differences capture the relative shift of neighboring voxels). + Multiplying a resampled EPI by this scalar field preserves total signal + through the regions that compress and expand under unwarping. + + Parameters + ---------- + vsm : :class:`numpy.ndarray` + Voxel shift map, :math:`\mathrm{VSM} = f_{\mathrm{Hz}}\cdot t_{\mathrm{ro}}`, + in voxel units. 3D ``(I, J, K)`` for a static fieldmap, or 4D + ``(I, J, K, T)`` for a per-volume dynamic fieldmap (e.g. MEDIC). The + readout time already carries PE polarity and any data-orientation + flip, so the sign must be applied before computing the VSM. + pe_axis : :obj:`int` + Spatial axis index (``0``, ``1`` or ``2``) along which the EPI distorts. + """ + return 1 + np.gradient(vsm, axis=pe_axis) + + async def worker( data: np.ndarray, coordinates: np.ndarray, @@ -159,7 +182,9 @@ async def unwarp_parallel( An array of shape (3, I, J, K) array providing the voxel (index) coordinates of the reference image (i.e., interpolated points) before SDC/HMC. fmap_hz : :obj:`~numpy.ndarray` - An array of shape (I, J, K) containing the displacement of each voxel in voxel units. + The :math:`B_0` field in Hz. 3D ``(I, J, K)`` for a field shared across + all volumes (static estimators), or 4D ``(I, J, K, T)`` for one field + per EPI volume (dynamic estimators, e.g. MEDIC). pe_info : :obj:`tuple` of (:obj:`int`, :obj:`float`) A tuple containing the index of the phase-encoding axis in the data array and the readout time (including sign, if displacements must be reversed) @@ -201,21 +226,29 @@ async def unwarp_parallel( if fulldataset.ndim == 3: fulldataset = fulldataset[..., np.newaxis] - func = partial( - _sdc_unwarp, - jacobian=jacobian, - fmap_hz=fmap_hz, - output_dtype=output_dtype, - order=order, - mode=mode, - cval=cval, - prefilter=prefilter, - ) + n_volumes = fulldataset.shape[-1] - # Create a worker task for each chunk + # Normalize to a per-frame 4D field: a 3D field is shared across all volumes + # (static estimators), so broadcast it; a 4D field already carries one Hz + # volume per EPI frame (dynamic estimators, e.g. MEDIC). After this, frame + # selection is a single, branchless ``fmap_hz[..., volid]`` below. + if fmap_hz.ndim == 3: + fmap_hz = np.broadcast_to(fmap_hz[..., np.newaxis], (*fmap_hz.shape, n_volumes)) + + # Create a worker task for each volume tasks = [] for volid, volume in enumerate(np.rollaxis(fulldataset, -1, 0)): xfm = None if xfms is None else xfms[volid] + func = partial( + _sdc_unwarp, + jacobian=jacobian, + fmap_hz=fmap_hz[..., volid], + output_dtype=output_dtype, + order=order, + mode=mode, + cval=cval, + prefilter=prefilter, + ) # IMPORTANT - the coordinates array must be copied every time anew per thread task = asyncio.create_task( @@ -245,10 +278,15 @@ class B0FieldTransform: coeffs = attr.ib(default=None) """B-Spline coefficients (one value per control point).""" - mapped = attr.ib(default=None, init=False) + mapped = attr.ib(default=None) """ - A cache of the interpolated field in Hz (i.e., the fieldmap *mapped* on to the - target image we want to correct). + The fieldmap in Hz, mapped onto the target image we want to correct. + + Populated by :meth:`fit` from :attr:`coeffs` (static B-Spline estimators), + or supplied directly at construction when the field is already on the + target grid (dynamic estimators, e.g. MEDIC). May be 3D ``(I, J, K)`` — + one field shared by every EPI volume — or 4D ``(I, J, K, T)`` — one field + per volume. """ def fit( @@ -460,61 +498,31 @@ def apply( # Make sure the data array has all cosines positive (i.e., no axes are flipped) moving, axcodes = ensure_positive_cosines(moving) - if self.mapped is not None: - warn( - 'The fieldmap has been already fit, the user is responsible for ' - 'ensuring the parameters of the EPI target are consistent.', - stacklevel=2, + if self.coeffs is None and self.mapped is None: + raise ValueError( + 'B0FieldTransform needs either B-Spline coefficients (coeffs) to fit, ' + 'or a pre-gridded fieldmap in Hz (mapped) to resample with.' ) - else: - # Generate warp field (before ensuring positive cosines) - self.fit(moving, xfm_data2fmap=xfm_data2fmap, approx=approx) - - # Squeeze non-spatial dimensions - newshape = moving.shape[:3] + tuple(dim for dim in moving.shape[3:] if dim > 1) - data = nb.arrayproxy.reshape_dataobj(moving.dataobj, newshape) - ndim = min(data.ndim, 3) - n_volumes = data.shape[3] if data.ndim == 4 else 1 - output_dtype = output_dtype or moving.header.get_data_dtype() - - # Prepare input parameters - if isinstance(pe_dir, str): - pe_dir = [pe_dir] - - if isinstance(ro_time, float): - ro_time = [ro_time] - - if n_volumes > 1 and len(pe_dir) == 1: - pe_dir *= n_volumes - - if n_volumes > 1 and len(ro_time) == 1: - ro_time *= n_volumes - - pe_info = [] - for vol_pe_dir, vol_ro_time in zip(pe_dir, ro_time, strict=False): - pe_axis = 'ijk'.index(vol_pe_dir[0]) - # Displacements are reversed if either is true (after ensuring positive cosines) - flip = (axcodes[pe_axis] in 'LPI') ^ vol_pe_dir.endswith('-') - - pe_info.append((pe_axis, -vol_ro_time if flip else vol_ro_time)) - - # Reference image's voxel coordinates (in voxel units) - voxcoords = ( - nt.linear.Affine(reference=moving) - .reference.ndindex.T.reshape((ndim, *data.shape[:ndim])) - .astype('float32') - ) - # Convert head-motion transforms to voxel-to-voxel: + if self.mapped is not None and self.coeffs is None: + # Pre-gridded field (e.g., MEDIC dynamic): already on the target + # grid, nothing to reconstruct — just normalize its orientation. + fmap_img, _ = ensure_positive_cosines(self.mapped) + fmap_hz = np.asanyarray(fmap_img.dataobj, dtype='float32') + else: + if self.mapped is not None: + warn( + 'The fieldmap has been already fit, the user is responsible for ' + 'ensuring the parameters of the EPI target are consistent.', + stacklevel=2, + ) + else: + # Generate warp field (before ensuring positive cosines) + self.fit(moving, xfm_data2fmap=xfm_data2fmap, approx=approx) + fmap_hz = self.mapped.get_fdata(dtype='float32') + + # Head-motion compensation is not yet wired through the unwarp. if xfms is not None: - # if len(xfms) != n_volumes: - # raise RuntimeError( - # f"Number of head-motion estimates ({len(xfms)}) does not match the " - # f"number of volumes ({n_volumes})" - # ) - # vox2ras = moving.affine.copy() - # ras2vox = np.linalg.inv(vox2ras) - # xfms = [ras2vox @ xfm @ vox2ras for xfm in xfms] xfms = None warn( 'Head-motion compensating (realignment) transforms are ignored when applying ' @@ -523,31 +531,23 @@ def apply( stacklevel=1, ) - # Resample - resampled = asyncio.run( - unwarp_parallel( - data, - voxcoords, - self.mapped.get_fdata(dtype='float32'), # fieldmap in Hz - pe_info, - xfms, - jacobian, - output_dtype='float32', - order=order, - mode=mode, - cval=cval, - prefilter=prefilter, - max_concurrent=num_threads or min(os.cpu_count(), 12), - ) + return _resample_with_fieldmap( + moving, + axcodes, + fmap_hz, + pe_dir, + ro_time, + xfms=xfms, + jacobian=jacobian, + order=order, + mode=mode, + cval=cval, + prefilter=prefilter, + output_dtype=output_dtype, + num_threads=num_threads, + allow_negative=allow_negative, ) - if not allow_negative: - resampled[resampled < 0] = cval - - moved = moving.__class__(resampled, moving.affine, moving.header) - moved.header.set_data_dtype(output_dtype) - return reorient_image(moved, axcodes) - def to_displacements(self, ro_time, pe_dir, itk_format=True): """ Generate a NIfTI file containing a displacements field transform compatible with ITK/ANTs. @@ -572,6 +572,96 @@ def to_displacements(self, ro_time, pe_dir, itk_format=True): return fmap_to_disp(self.mapped, ro_time, pe_dir, itk_format=itk_format) +def _resample_with_fieldmap( + moving, + axcodes, + fmap_hz: np.ndarray, + pe_dir, + ro_time, + *, + xfms: Sequence[np.ndarray] | None = None, + jacobian: bool = True, + order: int = 3, + mode: str = 'constant', + cval: float = 0.0, + prefilter: bool = True, + output_dtype: str | np.dtype | None = None, + num_threads: int | None = None, + allow_negative: bool = False, +): + """Resample ``moving`` through an on-grid Hz fieldmap (3D or 4D). + + Shared core of :meth:`B0FieldTransform.apply`. The caller is responsible + for producing ``fmap_hz`` already on the ``moving`` grid — B-spline + reconstruction plus coregistration for static estimators, or a pre-gridded + per-frame field for dynamic estimators — and for ensuring ``moving`` has + positive cosines. + + A 3D ``fmap_hz`` is shared across all EPI volumes; a 4D ``fmap_hz`` carries + one Hz volume per EPI frame and must match the number of volumes. + """ + # Squeeze non-spatial dimensions + newshape = moving.shape[:3] + tuple(dim for dim in moving.shape[3:] if dim > 1) + data = np.asanyarray(nb.arrayproxy.reshape_dataobj(moving.dataobj, newshape)) + ndim = min(data.ndim, 3) + n_volumes = data.shape[3] if data.ndim == 4 else 1 + output_dtype = output_dtype or moving.header.get_data_dtype() + + if fmap_hz.ndim == 4 and fmap_hz.shape[-1] != n_volumes: + raise ValueError( + f'Dynamic fieldmap frame count ({fmap_hz.shape[-1]}) does not match ' + f'EPI volumes ({n_volumes}).' + ) + + # Prepare input parameters + if isinstance(pe_dir, str): + pe_dir = [pe_dir] + if isinstance(ro_time, (int, float)): + ro_time = [float(ro_time)] + if n_volumes > 1 and len(pe_dir) == 1: + pe_dir = pe_dir * n_volumes + if n_volumes > 1 and len(ro_time) == 1: + ro_time = ro_time * n_volumes + + pe_info = [] + for vol_pe_dir, vol_ro_time in zip(pe_dir, ro_time, strict=False): + pe_axis = 'ijk'.index(vol_pe_dir[0]) + # Displacements are reversed if either is true (after ensuring positive cosines) + flip = (axcodes[pe_axis] in 'LPI') ^ vol_pe_dir.endswith('-') + pe_info.append((pe_axis, -vol_ro_time if flip else vol_ro_time)) + + # Reference image's voxel coordinates (in voxel units) + voxcoords = ( + nt.linear.Affine(reference=moving) + .reference.ndindex.T.reshape((ndim, *data.shape[:ndim])) + .astype('float32') + ) + + resampled = asyncio.run( + unwarp_parallel( + data, + voxcoords, + fmap_hz, + pe_info, + xfms, + jacobian, + output_dtype='float32', + order=order, + mode=mode, + cval=cval, + prefilter=prefilter, + max_concurrent=num_threads or min(os.cpu_count(), 12), + ) + ) + + if not allow_negative: + resampled[resampled < 0] = cval + + moved = moving.__class__(resampled, moving.affine, moving.header) + moved.header.set_data_dtype(output_dtype) + return reorient_image(moved, axcodes) + + def fmap_to_disp(fmap_nii, ro_time, pe_dir, itk_format=True): """ Convert a fieldmap in Hz into an ITK/ANTs-compatible displacements field. diff --git a/sdcflows/utils/tests/test_wrangler.py b/sdcflows/utils/tests/test_wrangler.py index d933625eca..b944d9f5fd 100644 --- a/sdcflows/utils/tests/test_wrangler.py +++ b/sdcflows/utils/tests/test_wrangler.py @@ -433,6 +433,76 @@ def gen_layout(bids_dir, database_dir=None): } +def _build_medic_skeleton(*, intent: str | None = 'intended_for'): + """Generate a 3-session × 3-echo × {mag,phase} BIDS skeleton for MEDIC. + + MEDIC is discovered only from BIDS intent metadata, never from file + structure alone. ``intent`` selects how that intent is expressed on each + complex BOLD: + + * ``"b0_identifier"`` — the BIDS-RECOMMENDED route: a self-referential + ``B0FieldIdentifier``/``B0FieldSource`` (the pattern BIDS endorses for + images that estimate their own B0 field). + * ``"intended_for"`` — the legacy route: ``IntendedFor`` listing the run's + 6 mag/phase siblings. + * ``None`` — no intent metadata at all, to confirm MEDIC does **not** fire + on structure alone. + """ + echo_times = {'1': 0.0142, '2': 0.03893, '3': 0.06366} + sessions = [] + for ses in ('01', '02', '03'): + intended_for = [ + ( + f'bids::sub-01/ses-{ses}/func/' + f'sub-01_ses-{ses}_task-rest_echo-{echo}_part-{part}_bold.nii.gz' + ) + for echo in echo_times + for part in ('mag', 'phase') + ] + func = [] + for echo, te in echo_times.items(): + for part in ('mag', 'phase'): + metadata = { + 'EchoTime': te, + 'RepetitionTime': 0.8, + 'TotalReadoutTime': 0.5, + 'PhaseEncodingDirection': 'j', + } + if intent == 'intended_for': + metadata['IntendedFor'] = intended_for + elif intent == 'b0_identifier': + # Every mag+phase echo is an *input* to the estimation, so + # all carry B0FieldIdentifier. Only the magnitude echoes are + # *corrected* analysis targets, so B0FieldSource sits on mag + # alone (the pepolar-style self-correction pattern). + b0_id = f'medic_ses{ses}' + metadata['B0FieldIdentifier'] = b0_id + if part == 'mag': + metadata['B0FieldSource'] = b0_id + func.append( + { + 'task': 'rest', + 'echo': echo, + 'part': part, + 'suffix': 'bold', + 'metadata': metadata, + } + ) + sessions.append( + { + 'session': ses, + 'anat': [{'suffix': 'T1w', 'metadata': {'EchoTime': 1}}], + 'func': func, + } + ) + return {'01': sessions} + + +medic = _build_medic_skeleton(intent='intended_for') +medic_b0_identifier = _build_medic_skeleton(intent='b0_identifier') +medic_no_intent = _build_medic_skeleton(intent=None) + + filters = { 'fmap': { 'datatype': 'fmap', @@ -440,6 +510,7 @@ def gen_layout(bids_dir, database_dir=None): }, 't1w': {'datatype': 'anat', 'session': '01', 'suffix': 'T1w'}, 'bold': {'datatype': 'func', 'session': '01', 'suffix': 'bold'}, + 'medic': {'datatype': ['fmap', 'func'], 'session': '01'}, } @@ -449,6 +520,7 @@ def gen_layout(bids_dir, database_dir=None): ('pepolar', pepolar, 1, 'fmap'), ('pepolar_b0ids', pepolar_b0ids, 1, 'bold'), ('phasediff', phasediff, 1, 'fmap'), + ('medic', medic, 1, 'medic'), ], ) def test_wrangler_filter(tmpdir, name, skeleton, estimations, bids_filters): @@ -466,6 +538,7 @@ def test_wrangler_filter(tmpdir, name, skeleton, estimations, bids_filters): ('pepolar', pepolar, 5, True), ('pepolar_b0ids', pepolar_b0ids, 2, False), ('phasediff', phasediff, 3, True), + ('medic', medic, 3, False), ], ) @pytest.mark.parametrize( @@ -496,6 +569,120 @@ def test_wrangler_URIs(tmpdir, name, skeleton, session, estimations, total_estim clear_registry() +def test_wrangler_medic_no_intent_does_not_fire(tmp_path): + """Structure alone must not trigger MEDIC. + + A complex multi-echo BOLD with no ``B0FieldIdentifier`` and no + ``IntendedFor`` carries no BIDS intent, so MEDIC must not be discovered + (``fmapless=False`` rules out the ANAT fallback, isolating the MEDIC path). + """ + bids_dir = str(tmp_path / 'medic_no_intent') + generate_bids_skeleton(bids_dir, medic_no_intent) + layout = gen_layout(bids_dir) + est = find_estimators(layout=layout, subject='01', fmapless=False) + assert est == [] + clear_registry() + + +@pytest.mark.parametrize( + ('skeleton', 'no_medic', 'expected'), + [ + # A: BIDS-recommended route — self-referential B0FieldIdentifier. + (medic_b0_identifier, False, 3), + # B: legacy route — IntendedFor on the complex BOLD sidecars. + (medic, False, 3), + # C/D: ``no_medic`` suppresses discovery via either route. + (medic_b0_identifier, True, 0), + (medic, True, 0), + ], + ids=['b0-identifier', 'intended-for', 'no_medic-b0', 'no_medic-intended-for'], +) +def test_wrangler_medic_trigger(tmp_path, skeleton, no_medic, expected): + """Metadata-driven MEDIC discovery and the ``no_medic`` override. + + * **b0-identifier**: complex BOLD carries a self-referential + ``B0FieldIdentifier`` (the BIDS-recommended route); discovered via Step 1. + * **intended-for**: complex BOLD carries ``IntendedFor`` (legacy route); + discovered via the dedicated MEDIC block. + * **no_medic-***: ``no_medic=True`` skips MEDIC via either route, so + nothing fires (``fmapless=False`` rules out the ANAT fallback). + """ + bids_dir = str(tmp_path / 'medic_trigger') + generate_bids_skeleton(bids_dir, skeleton) + layout = gen_layout(bids_dir) + estimators = find_estimators( + layout=layout, + subject='01', + fmapless=False, + no_medic=no_medic, + ) + assert len(estimators) == expected + for estimator in estimators: + assert estimator.method.name == 'MEDIC' + # 3 echoes × {mag, phase} per session. + assert len(estimator.sources) == 6 + clear_registry() + + +def test_wrangler_medic_ordered_first(tmp_path): + """MEDIC precedes static estimators in the returned list. + + Consumers (fMRIPrep) walk this list and select the first applicable + estimator per target, so a dynamic MEDIC estimator must come before a + coexisting static fieldmap. Here a PEPOLAR pair and a complex multi-echo + BOLD each carry their own ``B0FieldIdentifier``; both are discovered and + MEDIC must sort first regardless of ``B0FieldIdentifier`` iteration order. + """ + skeleton = { + '01': [ + { + 'anat': [{'suffix': 'T1w', 'metadata': {'EchoTime': 1}}], + 'func': [ + { + 'task': 'rest', + 'run': run, + 'suffix': 'bold', + 'metadata': { + 'RepetitionTime': 0.8, + 'TotalReadoutTime': 0.5, + 'PhaseEncodingDirection': ped, + 'B0FieldIdentifier': 'pepolar1', + 'B0FieldSource': 'pepolar1', + }, + } + for run, ped in ((1, 'j'), (2, 'j-')) + ] + + [ + { + 'task': 'medic', + 'echo': echo, + 'part': part, + 'suffix': 'bold', + 'metadata': { + 'EchoTime': te, + 'RepetitionTime': 0.8, + 'TotalReadoutTime': 0.5, + 'PhaseEncodingDirection': 'j', + 'B0FieldIdentifier': 'medic1', + **({'B0FieldSource': 'medic1'} if part == 'mag' else {}), + }, + } + for echo, te in (('1', 0.0142), ('2', 0.0389)) + for part in ('mag', 'phase') + ], + } + ] + } + bids_dir = str(tmp_path / 'medic_first') + generate_bids_skeleton(bids_dir, skeleton) + layout = gen_layout(bids_dir) + estimators = find_estimators(layout=layout, subject='01', fmapless=False) + methods = [e.method.name for e in estimators] + assert methods[0] == 'MEDIC', methods + assert 'PEPOLAR' in methods + clear_registry() + + def test_single_reverse_pedir(tmp_path): bids_dir = tmp_path / 'bids' generate_bids_skeleton(bids_dir, pepolar) diff --git a/sdcflows/utils/wrangler.py b/sdcflows/utils/wrangler.py index f22725167f..dfd8dd9fb4 100644 --- a/sdcflows/utils/wrangler.py +++ b/sdcflows/utils/wrangler.py @@ -74,6 +74,7 @@ def find_estimators( sessions: list[str] | None = None, fmapless: bool | set = True, force_fmapless: bool = False, + no_medic: bool = False, logger: logging.Logger | None = None, bids_filters: dict | None = None, anat_suffix: str | list[str] = 'T1w', @@ -103,6 +104,15 @@ def find_estimators( force_fmapless : :obj:`bool` When some other fieldmap estimation methods have been found, fieldmap-less estimation will be skipped except if ``force_fmapless`` is ``True``. + no_medic : :obj:`bool` + Disable MEDIC discovery entirely. MEDIC is discovered like any other + fieldmap — only from BIDS intent metadata, never from file structure + alone: a complex multi-echo BOLD becomes a MEDIC estimator when it + carries a ``B0FieldIdentifier`` (the self-referential pattern BIDS + endorses for images that estimate their own B0 field, as in + ``pepolar``) or, on legacy datasets, ``IntendedFor``. Setting + ``no_medic=True`` skips those MEDIC estimators (other fieldmaps are + unaffected). logger The logger used to relay messages. If not provided, one will be created. bids_filters @@ -119,6 +129,12 @@ def find_estimators( successfully been built (meaning, all necessary inputs and corresponding metadata are present in the given layout.) + The list is returned in a deterministic order: dynamic (MEDIC) + estimators first — an intentional priority so that consumers selecting + the first applicable estimator per target prefer MEDIC over a + coexisting static fieldmap — then the remaining estimators ordered by + ``bids_id``, with fieldmap-less (ANAT) estimators last. + Examples -------- Our ``ds000054`` dataset, created for *fMRIPrep*, only has one *phasediff* type of fieldmap @@ -326,7 +342,6 @@ def find_estimators( base_entities = { 'subject': subject, 'extension': ['.nii', '.nii.gz'], - 'part': ['mag', None], 'scope': 'raw', # Ensure derivatives are not captured } @@ -351,7 +366,10 @@ def find_estimators( # flatten lists from json (tupled in pybids for hashing), then unique b0_ids = reduce( set.union, - (listify(ids) for ids in layout.get_B0FieldIdentifiers(**base_entities)), + ( + listify(ids) + for ids in layout.get_B0FieldIdentifiers(session=sessions, **base_entities) + ), set(), ) @@ -372,6 +390,15 @@ def find_estimators( B0FieldIdentifier=f'"{b0_id}"', # Double quotes to match JSON, not Python repr regex_search=True, ) + + if no_medic and any( + fmap.entities.get('part') in ('mag', 'phase') for fmap in bare_ids + listed_ids + ): + # ``part``-tagged BOLD under a B0FieldIdentifier is a MEDIC + # source; ``no_medic`` suppresses it (other identifiers stand). + logger.debug('Skipping B0FieldIdentifier %s (MEDIC; no_medic set)', b0_id) + continue + try: e = fm.FieldmapEstimation( [ @@ -444,14 +471,18 @@ def find_estimators( _log_debug_estimation(logger, e, layout.root) estimators.append(e) - # At this point, only single-PE _epi files WITH ``IntendedFor`` can - # be automatically processed. + # At this point, only single-PE _epi/_bold files WITH ``IntendedFor`` + # can be automatically processed. has_intended = () with suppress(ValueError): has_intended = layout.get( **{ **base_entities, - **{'suffix': 'epi', 'IntendedFor': Query.REQUIRED, 'session': sessions}, + **{ + 'suffix': ['epi', 'bold'], + 'IntendedFor': Query.REQUIRED, + 'session': sessions, + }, } ) @@ -511,6 +542,89 @@ def find_estimators( _log_debug_estimation(logger, e, layout.root) estimators.append(e) + # MEDIC: multi-echo BOLD with mag+phase parts — legacy ``IntendedFor`` path. + # + # MEDIC is discovered from BIDS intent metadata only, never from file + # structure alone. The primary, BIDS-RECOMMENDED route is + # ``B0FieldIdentifier`` (handled by Step 1 above, where a complex + # multi-echo BOLD tagged with an identifier — the self-referential pattern + # BIDS endorses for images that estimate their own B0 field — is built as a + # MEDIC estimator). This block is the legacy fallback: when no + # ``B0FieldIdentifier`` is present, a complex multi-echo BOLD that declares + # ``IntendedFor`` is picked up here. Skipped entirely when ``no_medic`` is + # set or when ``B0FieldIdentifier`` metadata already drove discovery. + if not no_medic and not b0_ids: + medic_seed_query = { + **base_entities, + 'session': sessions, + 'suffix': 'bold', + 'echo': Query.REQUIRED, + 'IntendedFor': Query.REQUIRED, + } + + # Query both parts as seeds — datasets vary on which side carries + # ``IntendedFor`` — and rely on the dedup check below to keep the + # estimator unique per (run, echo-set). Each seed is then expanded + # to all matching mag+phase echo siblings of its run. + medic_seeds = [] + for part in ('phase', 'mag'): + with suppress(ValueError): + medic_seeds.extend(layout.get(**{**medic_seed_query, 'part': part})) + + for bold_fmap in medic_seeds: + # Pull every echo + part for this run. ``get_entities()`` already + # includes extension; we override part/echo to widen the query. + run_entities = { + k: v for k, v in bold_fmap.get_entities().items() if k not in ('part', 'echo') + } + run_entities['part'] = ['phase', 'mag'] + run_entities['echo'] = Query.ANY + run_entities['scope'] = base_entities['scope'] + complex_imgs = layout.get(**run_entities) + if not complex_imgs: + continue + + # Dedup against every prior estimator's full source set, not just + # the first complex image — pybids ordering is not contractual, + # and the same run can be seeded twice (once via phase, once via + # mag) when both parts carry ``IntendedFor``. + already_claimed = {str(s.path) for est in estimators for s in est.sources} + if any(str(c.path) in already_claimed for c in complex_imgs): + logger.debug('Skipping MEDIC fmap %s (already in use)', complex_imgs[0].relpath) + continue + + try: + e = fm.FieldmapEstimation( + [ + fm.FieldmapFile( + img.path, + metadata=_filter_metadata(img.get_metadata(), subject), + ) + for img in complex_imgs + ] + ) + except (ValueError, TypeError) as err: + _log_debug_estimator_fail( + logger, 'unnamed MEDIC', list(complex_imgs), layout.root, str(err) + ) + else: + _log_debug_estimation(logger, e, layout.root) + estimators.append(e) + + # Return estimators in a stable, deterministic order so the same dataset + # always yields the same list (Step 1 iterates a ``set`` of + # ``B0FieldIdentifier``s, whose order is otherwise hash-seed dependent). + # + # The ordering encodes one intentional, documented priority: dynamic + # (MEDIC) estimators come first. A consumer that simply walks this list and + # takes the first estimator applicable to a target therefore prefers MEDIC + # over a coexisting static fieldmap. Ties (and all non-MEDIC estimators) + # are then ordered by ``bids_id`` — which preserves discovery order for the + # auto-named heuristic path (``auto_NNNNN`` ids are monotonic) and is + # alphabetical for explicitly named ``B0FieldIdentifier`` estimators. + # Fieldmap-less ANAT estimators are appended after this point and stay last. + estimators.sort(key=lambda e: (not e.is_dynamic, e.bids_id)) + if estimators and not force_fmapless: fmapless = False diff --git a/sdcflows/workflows/apply/dynamic.py b/sdcflows/workflows/apply/dynamic.py new file mode 100644 index 0000000000..5e7f5b2e39 --- /dev/null +++ b/sdcflows/workflows/apply/dynamic.py @@ -0,0 +1,176 @@ +# emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*- +# vi: set ft=python sts=4 ts=4 sw=4 et: +# +# Copyright The NiPreps Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# We support and encourage derived works from this project, please read +# about our expectations at +# +# https://www.nipreps.org/community/licensing/ +# +"""Per-volume distortion correction using a dynamic (4D) fieldmap. + +Counterpart to :func:`~sdcflows.workflows.apply.correction.init_unwarp_wf`, +specialized for fieldmaps that vary across time (typically MEDIC). Where +the static apply path interpolates a single B-spline-encoded field onto +the EPI grid and applies the same warp to every volume, this workflow +takes a 4D Hz fieldmap *already on the EPI grid* and applies a different +warp to each timepoint. +""" + +from nipype.interfaces import utility as niu +from nipype.pipeline import engine as pe +from niworkflows.engine.workflows import LiterateWorkflow as Workflow + +INPUT_FIELDS = ('distorted', 'metadata', 'fmap') + + +def init_dynamic_unwarp_wf( + *, + jacobian=True, + omp_nthreads=1, + name='dynamic_unwarp_wf', +): + r""" + Apply a per-volume 4D fieldmap to unwarp a 4D EPI series. + + Workflow Graph + .. workflow :: + :graph2use: orig + :simple_form: yes + + from sdcflows.workflows.apply.dynamic import init_dynamic_unwarp_wf + wf = init_dynamic_unwarp_wf() + + Parameters + ---------- + jacobian : :obj:`bool` + If :obj:`True`, apply Jacobian determinant correction after + resampling, preserving total signal through compression/expansion + regions of the EPI distortion. Mirrors the + :func:`~sdcflows.workflows.apply.correction.init_unwarp_wf` default. + omp_nthreads : :obj:`int` + Maximum number of parallel volume resamplings. + name : :obj:`str` + Workflow name. + + Inputs + ------ + distorted : :obj:`str` + 4D EPI series to unwarp. Must share frame count with ``fmap``. + metadata : :obj:`dict` + BIDS sidecar metadata. ``TotalReadoutTime`` and + ``PhaseEncodingDirection`` are required. + fmap : :obj:`str` + 4D B\ :sub:`0` field map in Hz, already on the EPI grid (typically + from :func:`~sdcflows.workflows.fit.medic.init_medic_wf`). + + Outputs + ------- + corrected : :obj:`str` + 4D unwarped EPI. + corrected_ref : :obj:`str` + 3D temporal-mean reference of the corrected series, brain-extracted. + corrected_mask : :obj:`str` + Binary brain mask co-registered with ``corrected_ref``. + """ + from niworkflows.interfaces.images import RobustAverage + + from ...interfaces.epi import GetReadoutTime + from ..ancillary import init_brainextraction_wf + + workflow = Workflow(name=name) + + inputnode = pe.Node(niu.IdentityInterface(fields=INPUT_FIELDS), name='inputnode') + outputnode = pe.Node( + niu.IdentityInterface( + fields=['corrected', 'corrected_ref', 'corrected_mask'], + ), + name='outputnode', + ) + + rotime = pe.Node(GetReadoutTime(), name='rotime', run_without_submitting=True) + + # No coregistration step: the dynamic fieldmap is on the EPI grid by + # construction (e.g., warpkit's MEDIC output is computed from the same + # multi-echo acquisition being corrected here), so the static path's + # ``fmap2data_xfm`` plumbing has no analog. + unwarp = pe.Node( + niu.Function( + input_names=[ + 'distorted', + 'fmap', + 'pe_direction', + 'readout_time', + 'jacobian', + 'num_threads', + ], + output_names=['out_file'], + function=_dynamic_unwarp, + ), + name='unwarp', + n_procs=omp_nthreads, + ) + unwarp.inputs.jacobian = jacobian + unwarp.inputs.num_threads = omp_nthreads + + average = pe.Node(RobustAverage(mc_method=None), name='average') + brainextraction_wf = init_brainextraction_wf() + + workflow.connect([ + (inputnode, rotime, [('distorted', 'in_file'), + ('metadata', 'metadata')]), + (inputnode, unwarp, [('distorted', 'distorted'), + ('fmap', 'fmap')]), + (rotime, unwarp, [('pe_direction', 'pe_direction'), + ('readout_time', 'readout_time')]), + (unwarp, average, [('out_file', 'in_file')]), + (unwarp, outputnode, [('out_file', 'corrected')]), + (average, brainextraction_wf, [('out_file', 'inputnode.in_file')]), + (brainextraction_wf, outputnode, [ + ('outputnode.out_file', 'corrected_ref'), + ('outputnode.out_mask', 'corrected_mask'), + ]), + ]) # fmt:skip + + return workflow + + +def _dynamic_unwarp(distorted, fmap, pe_direction, readout_time, jacobian, num_threads): + """Resample a 4D EPI through a per-frame 4D Hz fieldmap on the same grid. + + The 4D fieldmap is handed to :class:`~sdcflows.transform.B0FieldTransform` + as a pre-gridded field (no B-spline reconstruction or coregistration), so it + flows through the same resampling machinery as the static path. + """ + from pathlib import Path + + import nibabel as nb + + from sdcflows.transform import B0FieldTransform + + resampled = B0FieldTransform(mapped=nb.load(fmap)).apply( + distorted, + pe_dir=pe_direction, + ro_time=readout_time, + jacobian=jacobian, + num_threads=num_threads, + ) + # Return a ``str`` (not ``Path``): nipype prunes a node's working dir to the + # files referenced by its string-valued outputs, so a ``PosixPath`` return + # leaves ``corrected.nii.gz`` unrecognized and it gets deleted post-run. + out_file = str(Path('corrected.nii.gz').absolute()) + resampled.to_filename(out_file) + return out_file diff --git a/sdcflows/workflows/apply/tests/test_dynamic.py b/sdcflows/workflows/apply/tests/test_dynamic.py new file mode 100644 index 0000000000..c1c518a159 --- /dev/null +++ b/sdcflows/workflows/apply/tests/test_dynamic.py @@ -0,0 +1,179 @@ +# emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*- +# vi: set ft=python sts=4 ts=4 sw=4 et: +# +# Copyright The NiPreps Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# We support and encourage derived works from this project, please read +# about our expectations at +# +# https://www.nipreps.org/community/licensing/ +# +"""Tests for the per-volume dynamic apply workflow.""" + +from json import loads +from pathlib import Path + +import pytest + +from ..dynamic import INPUT_FIELDS, init_dynamic_unwarp_wf + + +def test_dynamic_unwarp_construct(): + """Build the workflow and verify shape.""" + wf = init_dynamic_unwarp_wf() + assert wf.name == 'dynamic_unwarp_wf' + + inputnode = wf.get_node('inputnode') + outputnode = wf.get_node('outputnode') + assert inputnode is not None and outputnode is not None + assert set(inputnode.outputs.copyable_trait_names()) >= set(INPUT_FIELDS) + assert set(outputnode.inputs.copyable_trait_names()) >= { + 'corrected', + 'corrected_ref', + 'corrected_mask', + } + + for node_name in ('rotime', 'unwarp', 'average'): + assert wf.get_node(node_name) is not None, f'missing node {node_name!r}' + + +def test_dynamic_unwarp_jacobian_flag_propagates(): + """The ``jacobian`` ctor flag forwards to the per-volume resampler.""" + wf = init_dynamic_unwarp_wf(jacobian=False) + unwarp = wf.get_node('unwarp') + assert unwarp.inputs.jacobian is False + + wf = init_dynamic_unwarp_wf(jacobian=True) + unwarp = wf.get_node('unwarp') + assert unwarp.inputs.jacobian is True + + +def test_dynamic_unwarp_matches_static(tmp_path, monkeypatch): + """For a 4D fmap with identical frames, per-volume resampling matches the + static path frame-by-frame. + + This pins the pre-gridded :class:`sdcflows.transform.B0FieldTransform` path + to the same Hz→VSM + scipy.ndimage convention as the rest of the codebase — + if the static path ever changes its sign or pe_info handling, this test + catches the drift. + """ + import nibabel as nb + import numpy as np + + from sdcflows.transform import B0FieldTransform, _sdc_unwarp + from sdcflows.utils.tools import ensure_positive_cosines + + monkeypatch.chdir(tmp_path) + + rng = np.random.default_rng(0) + shape = (5, 7, 5) + n_frames = 3 + affine = np.eye(4) + + fmap_3d = rng.normal(scale=0.5, size=shape).astype('float32') + fmap_4d = np.broadcast_to(fmap_3d[..., None], (*shape, n_frames)).astype('float32') + distorted = rng.normal(size=(*shape, n_frames)).astype('float32') + + distorted_path = tmp_path / 'distorted.nii.gz' + fmap_path = tmp_path / 'fmap.nii.gz' + nb.Nifti1Image(distorted, affine).to_filename(distorted_path) + nb.Nifti1Image(fmap_4d, affine).to_filename(fmap_path) + + resampled = B0FieldTransform(mapped=nb.load(str(fmap_path))).apply( + str(distorted_path), + pe_dir='j', + ro_time=0.1, + jacobian=True, + order=1, + prefilter=False, + num_threads=1, + allow_negative=True, + ) + out_data = np.asanyarray(resampled.dataobj) + + # Run the same primitive directly, per-frame, with no parallelism. + img, axcodes = ensure_positive_cosines(nb.load(str(distorted_path))) + voxcoords = np.indices(shape, dtype='float32') + pe_axis = 'ijk'.index('j') + flip = (axcodes[pe_axis] in 'LPI') ^ False + pe_info = (pe_axis, -0.1 if flip else 0.1) + expected = np.stack( + [ + _sdc_unwarp( + distorted[..., t], + voxcoords.copy(), + pe_info, + None, + jacobian=True, + fmap_hz=fmap_3d, + output_dtype='float32', + order=1, + prefilter=False, + ) + for t in range(n_frames) + ], + axis=-1, + ) + assert np.allclose(out_data, expected, atol=1e-5) + + +@pytest.mark.veryslow +def test_dynamic_unwarp_run( + tmpdir, datadir, workdir, medic_fixture, medic_test_volumes, truncate_to_volumes +): + """End-to-end run: estimate via MEDIC then apply via this workflow. + + Skipped without ``warpkit`` or without the multi-echo fixture under + ``$TEST_DATA_HOME``. See ``test_medic_run`` for fetch instructions. + """ + pytest.importorskip('warpkit') + + from sdcflows.workflows.fit.medic import init_medic_wf + + dataset, pattern = medic_fixture + full_pattern = f'{dataset}/{pattern}' + magnitude_files = sorted(Path(datadir).glob(full_pattern)) + if not magnitude_files: + pytest.skip(f'no MEDIC fixtures found under {datadir}/{dataset}') + + phase_files = [f.with_name(f.name.replace('part-mag', 'part-phase')) for f in magnitude_files] + metadata = [ + loads(f.with_name(f.name.replace('.nii.gz', '.json')).read_text()) for f in phase_files + ] + + tmpdir.chdir() + trunc_dir = Path(str(tmpdir)) / 'trunc' + trunc_dir.mkdir(exist_ok=True) + magnitude_files = truncate_to_volumes(magnitude_files, medic_test_volumes, trunc_dir) + phase_files = truncate_to_volumes(phase_files, medic_test_volumes, trunc_dir) + + fit_wf = init_medic_wf(omp_nthreads=2) + fit_wf.inputs.inputnode.magnitude = [str(f) for f in magnitude_files] + fit_wf.inputs.inputnode.phase = [str(f) for f in phase_files] + fit_wf.inputs.inputnode.metadata = metadata + + apply_wf = init_dynamic_unwarp_wf(omp_nthreads=2) + # Use the first-echo magnitude as the distorted target. + apply_wf.inputs.inputnode.distorted = str(magnitude_files[0]) + apply_wf.inputs.inputnode.metadata = metadata[0] + + from niworkflows.engine.workflows import LiterateWorkflow as Workflow + + wf = Workflow(name=f'medic_apply_{magnitude_files[0].stem.replace(".nii", "")}') + wf.connect([(fit_wf, apply_wf, [('outputnode.fmap', 'inputnode.fmap')])]) + + if workdir: + wf.base_dir = str(workdir) + wf.run(plugin='Linear') diff --git a/sdcflows/workflows/base.py b/sdcflows/workflows/base.py index 40f938b6a6..bc281805b9 100644 --- a/sdcflows/workflows/base.py +++ b/sdcflows/workflows/base.py @@ -82,6 +82,7 @@ def init_fmap_preproc_wf( """ from sdcflows.fieldmaps import EstimatorType + from sdcflows.workflows.fit.medic import INPUT_FIELDS as _medic_fields from sdcflows.workflows.fit.pepolar import INPUT_FIELDS as _pepolar_fields from sdcflows.workflows.fit.syn import INPUT_FIELDS as _syn_fields from sdcflows.workflows.outputs import init_fmap_derivatives_wf, init_fmap_reports_wf @@ -89,11 +90,19 @@ def init_fmap_preproc_wf( INPUT_FIELDS = { EstimatorType.ANAT: _syn_fields, EstimatorType.PEPOLAR: _pepolar_fields, + EstimatorType.MEDIC: _medic_fields, } workflow = Workflow(name=name) - out_fields = ('fmap', 'fmap_coeff', 'fmap_ref', 'fmap_mask', 'fmap_id', 'method') + out_fields = ( + 'fmap', + 'fmap_coeff', + 'fmap_ref', + 'fmap_mask', + 'fmap_id', + 'method', + ) out_merge = {f: pe.Node(niu.Merge(len(estimators)), name=f'out_merge_{f}') for f in out_fields} # Fieldmaps and coefficient files can come in pairs, ensure they are not flattened out_merge['fmap'].inputs.no_flatten = True @@ -134,9 +143,12 @@ def init_fmap_preproc_wf( ) out_map.inputs.fmap_id = estimator.bids_id + # Dynamic estimators (currently MEDIC) emit a 4D fieldmap directly on + # the EPI grid; no B-spline coefficient representation is produced. + is_dynamic = estimator.is_dynamic fmap_derivatives_wf = init_fmap_derivatives_wf( output_dir=str(output_dir), - write_coeff=True, + write_coeff=not is_dynamic, write_mask=True, bids_fmap_id=estimator.bids_id, name=f'fmap_derivatives_wf_{estimator.sanitized_id}', @@ -162,13 +174,25 @@ def init_fmap_preproc_wf( (inputnode, est_wf, [(f, f"inputnode.{f}") for f in fields]) ]) # fmt:skip + deriv_conns = [ + ('outputnode.fmap', 'inputnode.fieldmap'), + ('outputnode.fmap_ref', 'inputnode.fmap_ref'), + ('outputnode.fmap_mask', 'inputnode.fmap_mask'), + ] + out_map_conns = [ + ('outputnode.fieldmap', 'fmap'), + ('outputnode.fmap_ref', 'fmap_ref'), + ('outputnode.fmap_mask', 'fmap_mask'), + ] + if not is_dynamic: + deriv_conns.append(('outputnode.fmap_coeff', 'inputnode.fmap_coeff')) + out_map_conns.append(('outputnode.fmap_coeff', 'fmap_coeff')) + else: + # Keep the merge node aligned across estimators that don't emit coeffs. + out_map.inputs.fmap_coeff = None + workflow.connect([ - (est_wf, fmap_derivatives_wf, [ - ("outputnode.fmap", "inputnode.fieldmap"), - ("outputnode.fmap_ref", "inputnode.fmap_ref"), - ("outputnode.fmap_coeff", "inputnode.fmap_coeff"), - ("outputnode.fmap_mask", "inputnode.fmap_mask"), - ]), + (est_wf, fmap_derivatives_wf, deriv_conns), (est_wf, fmap_reports_wf, [ ("outputnode.fmap", "inputnode.fieldmap"), ("outputnode.fmap_ref", "inputnode.fmap_ref"), @@ -177,12 +201,7 @@ def init_fmap_preproc_wf( (est_wf, out_map, [ ("outputnode.method", "method") ]), - (fmap_derivatives_wf, out_map, [ - ("outputnode.fieldmap", "fmap"), - ("outputnode.fmap_ref", "fmap_ref"), - ("outputnode.fmap_coeff", "fmap_coeff"), - ("outputnode.fmap_mask", "fmap_mask"), - ]), + (fmap_derivatives_wf, out_map, out_map_conns), ]) # fmt:skip for field, mergenode in out_merge.items(): diff --git a/sdcflows/workflows/fit/medic.py b/sdcflows/workflows/fit/medic.py new file mode 100644 index 0000000000..2175994376 --- /dev/null +++ b/sdcflows/workflows/fit/medic.py @@ -0,0 +1,211 @@ +# emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*- +# vi: set ft=python sts=4 ts=4 sw=4 et: +# +# Copyright The NiPreps Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# We support and encourage derived works from this project, please read +# about our expectations at +# +# https://www.nipreps.org/community/licensing/ +# +"""MEDIC dynamic distortion correction (multi-echo phase + magnitude). + +Backed by `warpkit `__, a standard +``sdcflows`` dependency (on Python >= 3.11) that carries a Washington +University **non-commercial** license — commercial use requires a separate +WUSTL OTM agreement. Importing this module does not require warpkit; the +dependency is only resolved when the +:class:`~sdcflows.interfaces.warpkit.UnwrapPhase` and +:class:`~sdcflows.interfaces.warpkit.ComputeFieldmap` interfaces actually run. +""" + +from nipype.interfaces import utility as niu +from nipype.pipeline import engine as pe +from niworkflows.engine.workflows import LiterateWorkflow as Workflow + +INPUT_FIELDS = ('phase', 'magnitude', 'metadata') + + +def init_medic_wf( + omp_nthreads=1, + sloppy=False, + debug=False, + name='medic_wf', + **kwargs, +): + """ + Estimate a fieldmap via MEDIC from multi-echo magnitude + phase EPI. + + Workflow Graph + .. workflow :: + :graph2use: orig + :simple_form: yes + + from sdcflows.workflows.fit.medic import init_medic_wf + wf = init_medic_wf() + + Parameters + ---------- + omp_nthreads : :obj:`int` + Maximum number of threads warpkit may use. + sloppy : :obj:`bool` + Accepted for parity with other ``init_*_wf`` constructors; currently + unused for MEDIC. + debug : :obj:`bool` + Pass through to :class:`~sdcflows.interfaces.warpkit.UnwrapPhase`. + name : :obj:`str` + Workflow name. + + Inputs + ------ + phase : :obj:`list` of :obj:`str` + Phase NIfTI per echo. + magnitude : :obj:`list` of :obj:`str` + Magnitude NIfTI per echo. + metadata : :obj:`list` of :obj:`dict` + BIDS sidecar dicts, one per echo. Must contain ``EchoTime``, + ``TotalReadoutTime``, and ``PhaseEncodingDirection``. + + Outputs + ------- + fmap : :obj:`str` + 4D :math:`B_0` map in Hz, one volume per timepoint, already on the + EPI grid. Consumers must dispatch on dimensionality (3D for static + estimators, 4D for MEDIC) when applying. + fmap_ref : :obj:`str` + First-echo magnitude series, unprocessed: one volume per timepoint + matching ``fmap``. + fmap_mask : :obj:`str` + 4D binary brain mask (one mask per timepoint) as produced by MEDIC + (``warpkit``), aligned with ``fmap``. + method : :obj:`str` + Short description string. + + """ + # Project-internal imports only — none of these load warpkit at module + # import time. The warpkit dependency is resolved lazily inside the + # MEDIC interfaces at run time. + from ...interfaces.warpkit import ComputeFieldmap, UnwrapPhase + + workflow = Workflow(name=name) + workflow.__desc__ = """\ +A dynamic *B0* nonuniformity map was estimated from multi-echo +magnitude and phase EPI series using MEDIC [@van2026medic], as implemented in +``warpkit``. +""" + + inputnode = pe.Node(niu.IdentityInterface(fields=INPUT_FIELDS), name='inputnode') + outputnode = pe.Node( + niu.IdentityInterface( + fields=[ + 'fmap', + 'fmap_ref', + 'fmap_mask', + 'method', + ], + ), + name='outputnode', + ) + outputnode.inputs.method = 'MEDIC (multi-echo dynamic distortion correction)' + + # Pull echo_times / TRT / PED from sidecar dicts so warpkit gets them as + # direct args. (The interfaces also accept JSON sidecar paths, but the + # upstream sdcflows layer passes dicts.) + extract_meta = pe.Node( + niu.Function( + input_names=['metadata'], + output_names=['echo_times', 'total_readout_time', 'phase_encoding_direction'], + function=_unpack_metadata, + ), + name='extract_meta', + run_without_submitting=True, + ) + + # Two-stage warpkit path: UnwrapPhase exposes per-frame masks, which + # ComputeFieldmap then consumes. The one-shot MEDIC interface bundles + # both but doesn't materially differ for the fieldmap outputs we need. + unwrap = pe.Node( + UnwrapPhase(n_cpus=omp_nthreads, debug=debug), + name='unwrap', + n_procs=omp_nthreads, + ) + + # ComputeFieldmap doesn't expose a ``debug`` input — only UnwrapPhase + # does, so the asymmetry is intentional. + compute_fmap = pe.Node( + ComputeFieldmap(n_cpus=omp_nthreads), + name='compute_fmap', + n_procs=omp_nthreads, + ) + + # ``fmap_ref`` is just the first-echo magnitude series, passed through + # untouched. ``fmap_mask`` reuses the per-frame masks MEDIC already + # computes during phase unwrapping, so both track the per-volume fieldmap + # without any extra N4/skull-strip work. + pick_mag1 = pe.Node( + niu.Function( + input_names=['in_list'], + output_names=['out_file'], + function=_first, + ), + name='pick_mag1', + run_without_submitting=True, + ) + + # fmt: off + workflow.connect([ + (inputnode, extract_meta, [('metadata', 'metadata')]), + (inputnode, unwrap, [('phase', 'phase'), + ('magnitude', 'magnitude')]), + (extract_meta, unwrap, [('echo_times', 'echo_times')]), + (inputnode, compute_fmap, [('magnitude', 'magnitude')]), + (unwrap, compute_fmap, [('unwrapped', 'unwrapped'), + ('masks', 'masks')]), + (extract_meta, compute_fmap, [ + ('echo_times', 'echo_times'), + ('total_readout_time', 'total_readout_time'), + ('phase_encoding_direction', 'phase_encoding_direction'), + ]), + (compute_fmap, outputnode, [('fieldmap', 'fmap')]), + (inputnode, pick_mag1, [('magnitude', 'in_list')]), + (pick_mag1, outputnode, [('out_file', 'fmap_ref')]), + (unwrap, outputnode, [('masks', 'fmap_mask')]), + ]) + # fmt: on + + return workflow + + +def _unpack_metadata(metadata): + """Pull echo times (s→ms), TRT, and PE direction from BIDS sidecars.""" + if not metadata: + raise ValueError('MEDIC requires per-echo metadata.') + if len(metadata) < 2: + raise ValueError( + f'MEDIC requires at least two echoes; got {len(metadata)}. ' + '(FieldmapEstimation enforces this for wrangler-built workflows; ' + 'this guard catches direct callers that bypass it.)' + ) + echo_times = [float(m['EchoTime']) * 1000.0 for m in metadata] + total_readout_time = float(metadata[0]['TotalReadoutTime']) + phase_encoding_direction = metadata[0]['PhaseEncodingDirection'] + peds = {m['PhaseEncodingDirection'] for m in metadata} + if len(peds) > 1: + raise ValueError(f'MEDIC echoes must share PhaseEncodingDirection; got {sorted(peds)}.') + return echo_times, total_readout_time, phase_encoding_direction + + +def _first(in_list): + return in_list[0] if in_list else None diff --git a/sdcflows/workflows/fit/tests/test_medic.py b/sdcflows/workflows/fit/tests/test_medic.py new file mode 100644 index 0000000000..b171865f62 --- /dev/null +++ b/sdcflows/workflows/fit/tests/test_medic.py @@ -0,0 +1,146 @@ +# emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*- +# vi: set ft=python sts=4 ts=4 sw=4 et: +# +# Copyright The NiPreps Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# We support and encourage derived works from this project, please read +# about our expectations at +# +# https://www.nipreps.org/community/licensing/ +# +"""Tests for the MEDIC dynamic fieldmap workflow.""" + +from json import loads +from pathlib import Path + +import pytest + +from ..medic import INPUT_FIELDS, _first, _unpack_metadata, init_medic_wf + + +def test_medic_construct(): + """Build the workflow and verify its surface — no warpkit required. + + This guards against module-load regressions and confirms the inputnode/ + outputnode shape that ``init_fmap_preproc_wf`` depends on. + """ + wf = init_medic_wf() + assert wf.name == 'medic_wf' + + inputnode = wf.get_node('inputnode') + outputnode = wf.get_node('outputnode') + assert inputnode is not None and outputnode is not None + assert set(inputnode.outputs.copyable_trait_names()) >= set(INPUT_FIELDS) + assert set(outputnode.inputs.copyable_trait_names()) >= { + 'fmap', + 'fmap_ref', + 'fmap_mask', + 'method', + } + # method is set unconditionally at construction. + assert outputnode.inputs.method.startswith('MEDIC') + + # Core nodes wired in the order the workflow describes. + for name in ( + 'extract_meta', + 'unwrap', + 'compute_fmap', + 'pick_mag1', + ): + assert wf.get_node(name) is not None, f'missing node {name!r}' + + +def test_unpack_metadata_converts_te_to_ms(): + metadata = [ + {'EchoTime': 0.0142, 'TotalReadoutTime': 0.5, 'PhaseEncodingDirection': 'j'}, + {'EchoTime': 0.03893, 'TotalReadoutTime': 0.5, 'PhaseEncodingDirection': 'j'}, + ] + tes, trt, ped = _unpack_metadata(metadata) + assert tes == [pytest.approx(14.2), pytest.approx(38.93)] + assert trt == 0.5 + assert ped == 'j' + + +def test_unpack_metadata_rejects_single_echo(): + metadata = [{'EchoTime': 0.0142, 'TotalReadoutTime': 0.5, 'PhaseEncodingDirection': 'j'}] + with pytest.raises(ValueError, match='at least two echoes'): + _unpack_metadata(metadata) + + +def test_unpack_metadata_rejects_mixed_pe(): + metadata = [ + {'EchoTime': 0.0142, 'TotalReadoutTime': 0.5, 'PhaseEncodingDirection': 'j'}, + {'EchoTime': 0.03893, 'TotalReadoutTime': 0.5, 'PhaseEncodingDirection': 'j-'}, + ] + with pytest.raises(ValueError, match='PhaseEncodingDirection'): + _unpack_metadata(metadata) + + +def test_unpack_metadata_rejects_empty(): + with pytest.raises(ValueError, match='per-echo metadata'): + _unpack_metadata([]) + + +def test_first_helper(): + """``_first`` returns the head of the list or ``None`` when empty.""" + assert _first(['a', 'b', 'c']) == 'a' + assert _first([]) is None + + +@pytest.mark.veryslow +def test_medic_run( + tmpdir, datadir, workdir, outdir, medic_fixture, medic_test_volumes, truncate_to_volumes +): + """End-to-end MEDIC run on a real multi-echo BOLD. + + Skipped if ``warpkit`` is unavailable (e.g. Python 3.10) or if the + expected dataset is not present under ``$TEST_DATA_HOME``. To run it, + stage the dataset, e.g. + + .. code-block:: console + + # ds006926: OpenNeuro multi-echo mag+phase BOLD (publicly available) + cd $TEST_DATA_HOME + datalad install https://github.com/OpenNeuroDatasets/ds006926.git + datalad get -d ds006926 sub-a01/func/sub-a01_task-VisMot_acq-tr1800_* + + """ + pytest.importorskip('warpkit') + + dataset, pattern = medic_fixture + full_pattern = f'{dataset}/{pattern}' + magnitude_files = sorted(Path(datadir).glob(full_pattern)) + if not magnitude_files: + pytest.skip(f'no MEDIC fixtures found under {datadir}/{dataset}') + + phase_files = [f.with_name(f.name.replace('part-mag', 'part-phase')) for f in magnitude_files] + metadata = [ + loads(f.with_name(f.name.replace('.nii.gz', '.json')).read_text()) for f in phase_files + ] + + tmpdir.chdir() + trunc_dir = Path(str(tmpdir)) / 'trunc' + trunc_dir.mkdir(exist_ok=True) + magnitude_files = truncate_to_volumes(magnitude_files, medic_test_volumes, trunc_dir) + phase_files = truncate_to_volumes(phase_files, medic_test_volumes, trunc_dir) + + medic_wf = init_medic_wf(omp_nthreads=2) + medic_wf.inputs.inputnode.magnitude = [str(f) for f in magnitude_files] + medic_wf.inputs.inputnode.phase = [str(f) for f in phase_files] + medic_wf.inputs.inputnode.metadata = metadata + + if workdir: + medic_wf.base_dir = str(workdir) + medic_wf.run(plugin='Linear') diff --git a/sdcflows/workflows/outputs.py b/sdcflows/workflows/outputs.py index de72d4cb34..116e3b60d7 100644 --- a/sdcflows/workflows/outputs.py +++ b/sdcflows/workflows/outputs.py @@ -147,6 +147,7 @@ def init_fmap_derivatives_wf( One or more fieldmap file(s) of the BIDS dataset that will serve for naming reference. fieldmap The preprocessed fieldmap, in its original space with Hz units. + Can be 3D (static estimators) or 4D (dynamic estimators such as MEDIC). fmap_coeff Field coefficient(s) file(s) fmap_ref @@ -160,16 +161,33 @@ def init_fmap_derivatives_wf( workflow = pe.Workflow(name=name) inputnode = pe.Node( niu.IdentityInterface( - fields=['source_files', 'fieldmap', 'fmap_coeff', 'fmap_ref', 'fmap_mask', 'fmap_meta'] + fields=[ + 'source_files', + 'fieldmap', + 'fmap_coeff', + 'fmap_ref', + 'fmap_mask', + 'fmap_meta', + ] ), name='inputnode', ) outputnode = pe.Node( - niu.IdentityInterface(fields=['fieldmap', 'fmap_coeff', 'fmap_ref', 'fmap_mask']), + niu.IdentityInterface( + fields=[ + 'fieldmap', + 'fmap_coeff', + 'fmap_ref', + 'fmap_mask', + ] + ), name='outputnode', ) - merge_fmap = pe.Node(MergeSeries(), name='merge_fmap') + # ``allow_4D`` lets MEDIC's 4D Hz fieldmap pass through alongside the + # 3D outputs of the static estimators — MergeSeries splits 4D inputs + # into per-frame 3D and re-concatenates them. + merge_fmap = pe.Node(MergeSeries(allow_4D=True), name='merge_fmap') ds_reference = pe.Node( DerivativesDataSink( diff --git a/sdcflows/workflows/tests/test_outputs.py b/sdcflows/workflows/tests/test_outputs.py new file mode 100644 index 0000000000..a83ece9a98 --- /dev/null +++ b/sdcflows/workflows/tests/test_outputs.py @@ -0,0 +1,43 @@ +# emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*- +# vi: set ft=python sts=4 ts=4 sw=4 et: +# +# Copyright The NiPreps Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# We support and encourage derived works from this project, please read +# about our expectations at +# +# https://www.nipreps.org/community/licensing/ +# +"""Construction-only tests for :mod:`sdcflows.workflows.outputs`.""" + +from ..outputs import init_fmap_derivatives_wf + + +def test_fmap_derivatives_wf_default(tmp_path): + """Default workflow should expose only the static fieldmap sinks.""" + wf = init_fmap_derivatives_wf(output_dir=str(tmp_path)) + assert wf.get_node('ds_fieldmap') is not None + assert wf.get_node('ds_reference') is not None + # write_mask off by default. + assert wf.get_node('ds_mask') is None + + +def test_fmap_derivatives_wf_merge_fmap_allows_4d(tmp_path): + """``merge_fmap`` must accept 4D inputs so MEDIC's per-frame Hz fmap flows + through the same sink as the static estimators' 3D fmaps.""" + wf = init_fmap_derivatives_wf(output_dir=str(tmp_path)) + merge_fmap = wf.get_node('merge_fmap') + assert merge_fmap is not None + assert merge_fmap.inputs.allow_4D is True diff --git a/tox.ini b/tox.ini index 1cd1963132..3c65377653 100644 --- a/tox.ini +++ b/tox.ini @@ -71,7 +71,8 @@ pass_env = CLICOLOR CLICOLOR_FORCE PYTHON_GIL -extras = tests +extras = + tests setenv = pre: PIP_EXTRA_INDEX_URL=https://pypi.anaconda.org/scientific-python-nightly-wheels/simple pre: UV_INDEX=https://pypi.anaconda.org/scientific-python-nightly-wheels/simple