Skip to content

Commit

Permalink
RF: Add resolution parameter when retrieving MNIInfant files
Browse files Browse the repository at this point in the history
  • Loading branch information
mgxd committed Aug 28, 2024
1 parent dc42e25 commit 0bd5064
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 23 deletions.
15 changes: 8 additions & 7 deletions nibabies/workflows/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -515,10 +515,12 @@ def init_single_subject_wf(
]) # fmt:skip

if cifti_output and 'MNIInfant' in [ref.space for ref in spaces.references]:
mniinfant_res = 2 if config.workflow.cifti_output == '91k' else 1

select_MNIInfant_xfm = pe.Node(
KeySelect(
fields=['anat2std_xfm', 'std2anat_xfm'],
key=get_MNIInfant_key(spaces),
key=get_MNIInfant_key(spaces, mniinfant_res),
),
name='select_MNIInfant_xfm',
run_without_submitting=True,
Expand Down Expand Up @@ -950,11 +952,10 @@ def get_estimator(layout, fname):
return field_source


def get_MNIInfant_key(spaces: SpatialReferences) -> str:
def get_MNIInfant_key(spaces: SpatialReferences, res: str | int) -> str:
"""Parse spaces and return matching MNIInfant space, including cohort."""
for space in spaces.references:
# str formats as <reference.name>:<reference.spec>
if 'MNIInfant' in str(space) and 'res-2' in str(space):
return space.fullname
for ref in spaces.references:
if ref.space == 'MNIInfant' and f'res-{res}' in str(ref):
return ref.fullname

raise KeyError(f'MNIInfant (resolution 2x2x2) not found in SpatialReferences: {spaces}')
raise KeyError(f'MNIInfant (resolution {res}) not found in SpatialReferences: {spaces}')
30 changes: 14 additions & 16 deletions nibabies/workflows/bold/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -558,7 +558,8 @@ def init_bold_wf(
]),
]) # fmt:skip

if config.workflow.cifti_output:
cifti_output = config.workflow.cifti_output
if cifti_output:
from niworkflows.interfaces.fixes import FixHeaderApplyTransforms as ApplyTransforms

from nibabies.workflows.bold.alignment import (
Expand All @@ -581,7 +582,7 @@ def init_bold_wf(
)

bold_fsLR_resampling_wf = init_bold_fsLR_resampling_wf(
grayord_density=config.workflow.cifti_output,
grayord_density=cifti_output,
omp_nthreads=omp_nthreads,
mem_gb=mem_gb['resampled'],
)
Expand Down Expand Up @@ -615,7 +616,7 @@ def init_bold_wf(
subcortical_mni_alignment_wf = init_subcortical_mni_alignment_wf()

bold_grayords_wf = init_bold_grayords_wf(
grayord_density=config.workflow.cifti_output,
grayord_density=cifti_output,
repetition_time=all_metadata[0]['RepetitionTime'],
)

Expand All @@ -624,7 +625,7 @@ def init_bold_wf(
base_directory=output_dir,
dismiss_entities=DEFAULT_DISMISS_ENTITIES,
space='fsLR',
density=config.workflow.cifti_output,
density=cifti_output,
suffix='bold',
compress=False,
TaskName=all_metadata[0].get('TaskName'),
Expand All @@ -635,7 +636,8 @@ def init_bold_wf(
)
ds_bold_cifti.inputs.source_file = bold_file

inputnode.inputs.mniinfant_mask = get_MNIInfant_mask(spaces)
mniinfant_res = 2 if cifti_output == '91k' else 1
inputnode.inputs.mniinfant_mask = get_MNIInfant_mask(spaces, mniinfant_res)

workflow.connect([
# Resample BOLD to MNI152NLin6Asym, may duplicate bold_std_wf above
Expand Down Expand Up @@ -747,11 +749,11 @@ def init_bold_wf(
]) # fmt:skip

# MG: Carpetplot workflow only work with CIFTI
if config.workflow.cifti_output:
if cifti_output:
carpetplot_wf = init_carpetplot_wf(
mem_gb=mem_gb['resampled'],
metadata=all_metadata[0],
cifti_output=config.workflow.cifti_output,
cifti_output=cifti_output,
name='carpetplot_wf',
)

Expand Down Expand Up @@ -847,24 +849,20 @@ def _read_json(in_file):
return loads(Path(in_file).read_text())


def get_MNIInfant_mask(spaces: 'SpatialReferences') -> str:
def get_MNIInfant_mask(spaces: 'SpatialReferences', res: str | int) -> str:
"""Parse spaces and return matching MNIInfant space, including cohort."""
import templateflow.api as tf

mask = None
for ref in spaces.references:
# str formats as <reference.name>:<reference.spec>
if ref.space == 'MNIInfant' and ref.spec.get('res', '') != 'native':
mask = str(
if ref.space == 'MNIInfant' and f'res-{res}' in str(ref):
return str(
tf.get(
'MNIInfant',
cohort=ref.spec['cohort'],
resolution=2,
resolution=res,
desc='brain',
suffix='mask',
)
)

if mask is None:
raise FileNotFoundError('MNIInfant brain mask not found.')
return mask
raise FileNotFoundError(f'MNIInfant mask (resolution {res}) not found.')

0 comments on commit 0bd5064

Please sign in to comment.