Skip to content

Commit

Permalink
ENH: Add DisplacementFieldSubsamplingFactor
Browse files Browse the repository at this point in the history
Output ANTs displacement fields are very dense, sampled with the fixed
image. This is generally over-sampled for the purpose of downstream use.
Memory usage is very high and serialization and deserialization takes
quite some time.

Add a DisplacementFieldSubsamplingFactor parameter for downsampling the
resulting displacement fields. This is applied in all directions for
both the forward and inverse transform. If the
DisplacementFieldSubsamplingFactor is greater than 1, this is applied.
The current default is 2.

This uses the itk::DisplacementFieldTransformParametersAdapter. In the
future, we may want to increase the default, and / or use the
itk::GaussianSmoothingOnUpdateDisplacementFieldTransformParametersAdaptor
with larger factors to avoid aliasing.
  • Loading branch information
thewtex committed May 22, 2024
1 parent 7a03aab commit b9f2e71
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 5 deletions.
19 changes: 17 additions & 2 deletions include/itkANTSRegistration.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#include "itkCompositeTransform.h"
#include "itkDataObjectDecorator.h"
#include "itkantsRegistrationHelper.h"
#include "itkDisplacementFieldTransformParametersAdaptor.h"

namespace itk
{
Expand Down Expand Up @@ -240,11 +241,17 @@ class ANTSRegistration : public ProcessObject
/** Set/Get the optimizer weights. When set, this allows restricting the optimization
* of the displacement field, translation, rigid or affine transform on a per-component basis.
* For example, to limit the deformation or rotation of 3-D volume to the first two dimensions,
* specify a weight vector of ‘(1,1,0)’ for a 3D deformation field
* specify a weight vector of ‘(1,1,0)’ for a 3D displacement field
* or ‘(1,1,0,1,1,0)’ for a rigid transformation. */
itkSetMacro(RestrictTransformation, std::vector<ParametersValueType>);
itkGetConstReferenceMacro(RestrictTransformation, std::vector<ParametersValueType>);

/** Set/Get the subsampling factor for displacement fields results.
* A factor of 1 results in no subsampling. This is applied in all dimensions.
* The default is 2. */
itkSetMacro(DisplacementFieldSubsamplingFactor, unsigned int);
itkGetMacro(DisplacementFieldSubsamplingFactor, unsigned int);

virtual DecoratedOutputTransformType *
GetOutput(DataObjectPointerArraySizeType i);
virtual const DecoratedOutputTransformType *
Expand Down Expand Up @@ -286,6 +293,10 @@ class ANTSRegistration : public ProcessObject
DataObjectPointer MakeOutput(DataObjectPointerArraySizeType) override;
using RegistrationHelperType = ::ants::RegistrationHelper<TParametersValueType, FixedImageType::ImageDimension>;
using InternalImageType = typename RegistrationHelperType::ImageType; // float or double pixels
using DisplacementFieldTransformType = typename RegistrationHelperType::DisplacementFieldTransformType;
using DisplacementFieldType = typename DisplacementFieldTransformType::DisplacementFieldType;
using DisplacementFieldTransformParametersAdaptorType =
DisplacementFieldTransformParametersAdaptor<DisplacementFieldTransformType>;

template <typename TImage>
typename InternalImageType::Pointer
Expand Down Expand Up @@ -346,6 +357,7 @@ class ANTSRegistration : public ProcessObject
unsigned int m_Radius{ 4 };
bool m_CollapseCompositeTransform{ true };
bool m_MaskAllStages{ false };
unsigned int m_DisplacementFieldSubsamplingFactor{ 2 };

std::vector<unsigned int> m_SynIterations{ 40, 20, 0 };
std::vector<unsigned int> m_AffineIterations{ 2100, 1200, 1200, 10 };
Expand All @@ -355,7 +367,10 @@ class ANTSRegistration : public ProcessObject
std::vector<ParametersValueType> m_RestrictTransformation;

private:
typename RegistrationHelperType::Pointer m_Helper{ RegistrationHelperType::New() };
typename RegistrationHelperType::Pointer m_Helper{ RegistrationHelperType::New() };
typename DisplacementFieldTransformParametersAdaptorType::Pointer m_DisplacementFieldAdaptor{
DisplacementFieldTransformParametersAdaptorType::New()
};

#ifdef ITK_USE_CONCEPT_CHECKING
static_assert(TFixedImage::ImageDimension == TMovingImage::ImageDimension,
Expand Down
42 changes: 39 additions & 3 deletions include/itkANTSRegistration.hxx
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ ANTSRegistration<TFixedImage, TMovingImage, TParametersValueType>::PrintSelf(std
os << indent << "Radius: " << this->m_Radius << std::endl;
os << indent << "CollapseCompositeTransform: " << (this->m_CollapseCompositeTransform ? "On" : "Off") << std::endl;
os << indent << "MaskAllStages: " << (this->m_MaskAllStages ? "On" : "Off") << std::endl;
os << indent << "DisplacementFieldSubsamplingFactor: " << this->m_DisplacementFieldSubsamplingFactor << std::endl;

os << indent << "SynIterations: " << this->m_SynIterations << std::endl;
os << indent << "AffineIterations: " << this->m_AffineIterations << std::endl;
Expand Down Expand Up @@ -257,8 +258,8 @@ ANTSRegistration<TFixedImage, TMovingImage, TParametersValueType>::MakeOutput(Da
template <typename TFixedImage, typename TMovingImage, typename TParametersValueType>
template <typename TImage>
auto
ANTSRegistration<TFixedImage, TMovingImage, TParametersValueType>::CastImageToInternalType(
const TImage * inputImage) -> typename InternalImageType::Pointer
ANTSRegistration<TFixedImage, TMovingImage, TParametersValueType>::CastImageToInternalType(const TImage * inputImage) ->
typename InternalImageType::Pointer
{
using CastFilterType = CastImageFilter<TImage, InternalImageType>;
typename CastFilterType::Pointer castFilter = CastFilterType::New();
Expand Down Expand Up @@ -606,7 +607,7 @@ ANTSRegistration<TFixedImage, TMovingImage, TParametersValueType>::GenerateData(
{
itkExceptionMacro(<< "Unsupported transform type: " << this->GetTypeOfTransform());
}
this->UpdateProgress(0.95);
this->UpdateProgress(0.90);

typename OutputTransformType::Pointer forwardTransform = m_Helper->GetModifiableCompositeTransform();
if (m_CollapseCompositeTransform)
Expand All @@ -615,6 +616,41 @@ ANTSRegistration<TFixedImage, TMovingImage, TParametersValueType>::GenerateData(
}
this->SetForwardTransform(forwardTransform);

if (m_DisplacementFieldSubsamplingFactor > 1)
{
using TransformType = typename OutputTransformType::TransformType;
for (unsigned int i = 0; i < forwardTransform->GetNumberOfTransforms(); ++i)
{
typename TransformType::Pointer transform = forwardTransform->GetNthTransform(i);
typename DisplacementFieldTransformType::Pointer displacementFieldTransform =
dynamic_cast<DisplacementFieldTransformType *>(transform.GetPointer());
if (displacementFieldTransform)
{
// The transform is a DisplacementFieldTransform
displacementFieldTransform->Print(std::cout, 3);
const auto displacementField = displacementFieldTransform->GetDisplacementField();
m_DisplacementFieldAdaptor->SetTransform(displacementFieldTransform);
m_DisplacementFieldAdaptor->SetRequiredOrigin(displacementField->GetOrigin());
m_DisplacementFieldAdaptor->SetRequiredDirection(displacementField->GetDirection());
auto requiredSize = displacementField->GetLargestPossibleRegion().GetSize();
for (unsigned int i = 0; i < requiredSize.GetSizeDimension(); ++i)
{
requiredSize[i] /= m_DisplacementFieldSubsamplingFactor;
}
m_DisplacementFieldAdaptor->SetRequiredSize(requiredSize);
auto requiredSpacing = displacementField->GetSpacing();
for (unsigned int i = 0; i < requiredSpacing.GetVectorDimension(); ++i)
{
requiredSpacing[i] *= m_DisplacementFieldSubsamplingFactor;
}
m_DisplacementFieldAdaptor->SetRequiredSpacing(requiredSpacing);
m_DisplacementFieldAdaptor->AdaptTransformParameters();
displacementFieldTransform->Print(std::cout, 3);
}
}
}
this->UpdateProgress(0.95);

typename OutputTransformType::Pointer inverseTransform = OutputTransformType::New();
if (forwardTransform->GetInverse(inverseTransform))
{
Expand Down

0 comments on commit b9f2e71

Please sign in to comment.