From f92efa9898526cc89c278eadc5452486e7b3560b Mon Sep 17 00:00:00 2001 From: Matt McCormick Date: Wed, 22 May 2024 16:40:38 -0400 Subject: [PATCH] ENH: Add DisplacementFieldSubsamplingFactor Output ANTs displacement fields are very dense, sampled with the fixed image. This is generally over-sampled for the purpose of downstream use. Memory usage is very high and serialization and deserialization takes quite some time. Add a DisplacementFieldSubsamplingFactor parameter for downsampling the resulting displacement fields. This is applied in all directions for both the forward and inverse transform. If the DisplacementFieldSubsamplingFactor is greater than 1, this is applied. The current default is 2. This uses the itk::DisplacementFieldTransformParametersAdapter. In the future, we may want to increase the default, and / or use the itk::GaussianSmoothingOnUpdateDisplacementFieldTransformParametersAdaptor with larger factors to avoid aliasing. --- include/itkANTSRegistration.h | 19 +++++++++++++-- include/itkANTSRegistration.hxx | 42 ++++++++++++++++++++++++++++++--- 2 files changed, 56 insertions(+), 5 deletions(-) diff --git a/include/itkANTSRegistration.h b/include/itkANTSRegistration.h index a055d51..5692c0d 100644 --- a/include/itkANTSRegistration.h +++ b/include/itkANTSRegistration.h @@ -23,6 +23,7 @@ #include "itkCompositeTransform.h" #include "itkDataObjectDecorator.h" #include "itkantsRegistrationHelper.h" +#include "itkDisplacementFieldTransformParametersAdaptor.h" namespace itk { @@ -240,11 +241,17 @@ class ANTSRegistration : public ProcessObject /** Set/Get the optimizer weights. When set, this allows restricting the optimization * of the displacement field, translation, rigid or affine transform on a per-component basis. * For example, to limit the deformation or rotation of 3-D volume to the first two dimensions, - * specify a weight vector of ‘(1,1,0)’ for a 3D deformation field + * specify a weight vector of ‘(1,1,0)’ for a 3D displacement field * or ‘(1,1,0,1,1,0)’ for a rigid transformation. */ itkSetMacro(RestrictTransformation, std::vector); itkGetConstReferenceMacro(RestrictTransformation, std::vector); + /** Set/Get the subsampling factor for displacement fields results. + * A factor of 1 results in no subsampling. This is applied in all dimensions. + * The default is 2. */ + itkSetMacro(DisplacementFieldSubsamplingFactor, unsigned int); + itkGetMacro(DisplacementFieldSubsamplingFactor, unsigned int); + virtual DecoratedOutputTransformType * GetOutput(DataObjectPointerArraySizeType i); virtual const DecoratedOutputTransformType * @@ -286,6 +293,10 @@ class ANTSRegistration : public ProcessObject DataObjectPointer MakeOutput(DataObjectPointerArraySizeType) override; using RegistrationHelperType = ::ants::RegistrationHelper; using InternalImageType = typename RegistrationHelperType::ImageType; // float or double pixels + using DisplacementFieldTransformType = typename RegistrationHelperType::DisplacementFieldTransformType; + using DisplacementFieldType = typename DisplacementFieldTransformType::DisplacementFieldType; + using DisplacementFieldTransformParametersAdaptorType = + DisplacementFieldTransformParametersAdaptor; template typename InternalImageType::Pointer @@ -346,6 +357,7 @@ class ANTSRegistration : public ProcessObject unsigned int m_Radius{ 4 }; bool m_CollapseCompositeTransform{ true }; bool m_MaskAllStages{ false }; + unsigned int m_DisplacementFieldSubsamplingFactor{ 2 }; std::vector m_SynIterations{ 40, 20, 0 }; std::vector m_AffineIterations{ 2100, 1200, 1200, 10 }; @@ -355,7 +367,10 @@ class ANTSRegistration : public ProcessObject std::vector m_RestrictTransformation; private: - typename RegistrationHelperType::Pointer m_Helper{ RegistrationHelperType::New() }; + typename RegistrationHelperType::Pointer m_Helper{ RegistrationHelperType::New() }; + typename DisplacementFieldTransformParametersAdaptorType::Pointer m_DisplacementFieldAdaptor{ + DisplacementFieldTransformParametersAdaptorType::New() + }; #ifdef ITK_USE_CONCEPT_CHECKING static_assert(TFixedImage::ImageDimension == TMovingImage::ImageDimension, diff --git a/include/itkANTSRegistration.hxx b/include/itkANTSRegistration.hxx index e3a7e04..a24ad39 100644 --- a/include/itkANTSRegistration.hxx +++ b/include/itkANTSRegistration.hxx @@ -67,6 +67,7 @@ ANTSRegistration::PrintSelf(std os << indent << "Radius: " << this->m_Radius << std::endl; os << indent << "CollapseCompositeTransform: " << (this->m_CollapseCompositeTransform ? "On" : "Off") << std::endl; os << indent << "MaskAllStages: " << (this->m_MaskAllStages ? "On" : "Off") << std::endl; + os << indent << "DisplacementFieldSubsamplingFactor: " << this->m_DisplacementFieldSubsamplingFactor << std::endl; os << indent << "SynIterations: " << this->m_SynIterations << std::endl; os << indent << "AffineIterations: " << this->m_AffineIterations << std::endl; @@ -257,8 +258,8 @@ ANTSRegistration::MakeOutput(Da template template auto -ANTSRegistration::CastImageToInternalType( - const TImage * inputImage) -> typename InternalImageType::Pointer +ANTSRegistration::CastImageToInternalType(const TImage * inputImage) -> + typename InternalImageType::Pointer { using CastFilterType = CastImageFilter; typename CastFilterType::Pointer castFilter = CastFilterType::New(); @@ -606,7 +607,7 @@ ANTSRegistration::GenerateData( { itkExceptionMacro(<< "Unsupported transform type: " << this->GetTypeOfTransform()); } - this->UpdateProgress(0.95); + this->UpdateProgress(0.90); typename OutputTransformType::Pointer forwardTransform = m_Helper->GetModifiableCompositeTransform(); if (m_CollapseCompositeTransform) @@ -615,6 +616,41 @@ ANTSRegistration::GenerateData( } this->SetForwardTransform(forwardTransform); + if (m_DisplacementFieldSubsamplingFactor > 1) + { + using TransformType = typename OutputTransformType::TransformType; + for (unsigned int i = 0; i < forwardTransform->GetNumberOfTransforms(); ++i) + { + typename TransformType::Pointer transform = forwardTransform->GetNthTransform(i); + typename DisplacementFieldTransformType::Pointer displacementFieldTransform = + dynamic_cast(transform.GetPointer()); + if (displacementFieldTransform) + { + // The transform is a DisplacementFieldTransform + displacementFieldTransform->Print(std::cout, 3); + const auto displacementField = displacementFieldTransform->GetDisplacementField(); + m_DisplacementFieldAdaptor->SetTransform(displacementFieldTransform); + m_DisplacementFieldAdaptor->SetRequiredOrigin(displacementField->GetOrigin()); + m_DisplacementFieldAdaptor->SetRequiredDirection(displacementField->GetDirection()); + auto requiredSize = displacementField->GetLargestPossibleRegion().GetSize(); + for (unsigned int i = 0; i < requiredSize.GetSizeDimension(); ++i) + { + requiredSize[i] /= m_DisplacementFieldSubsamplingFactor; + } + m_DisplacementFieldAdaptor->SetRequiredSize(requiredSize); + auto requiredSpacing = displacementField->GetSpacing(); + for (unsigned int i = 0; i < requiredSpacing.GetVectorDimension(); ++i) + { + requiredSpacing[i] *= m_DisplacementFieldSubsamplingFactor; + } + m_DisplacementFieldAdaptor->SetRequiredSpacing(requiredSpacing); + m_DisplacementFieldAdaptor->AdaptTransformParameters(); + displacementFieldTransform->Print(std::cout, 3); + } + } + } + this->UpdateProgress(0.95); + typename OutputTransformType::Pointer inverseTransform = OutputTransformType::New(); if (forwardTransform->GetInverse(inverseTransform)) {