Skip to content

Commit

Permalink
Merge pull request #31 from thewtex/displacement-field-subsampling-fa…
Browse files Browse the repository at this point in the history
…ctor

ENH: Add DisplacementFieldSubsamplingFactor
  • Loading branch information
thewtex authored May 23, 2024
2 parents afd998c + f92efa9 commit 966cec0
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 5 deletions.
19 changes: 17 additions & 2 deletions include/itkANTSRegistration.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#include "itkCompositeTransform.h"
#include "itkDataObjectDecorator.h"
#include "itkantsRegistrationHelper.h"
#include "itkDisplacementFieldTransformParametersAdaptor.h"

namespace itk
{
Expand Down Expand Up @@ -240,11 +241,17 @@ class ANTSRegistration : public ProcessObject
/** Set/Get the optimizer weights. When set, this allows restricting the optimization
* of the displacement field, translation, rigid or affine transform on a per-component basis.
* For example, to limit the deformation or rotation of 3-D volume to the first two dimensions,
* specify a weight vector of ‘(1,1,0)’ for a 3D deformation field
* specify a weight vector of ‘(1,1,0)’ for a 3D displacement field
* or ‘(1,1,0,1,1,0)’ for a rigid transformation. */
itkSetMacro(RestrictTransformation, std::vector<ParametersValueType>);
itkGetConstReferenceMacro(RestrictTransformation, std::vector<ParametersValueType>);

/** Set/Get the subsampling factor for displacement fields results.
* A factor of 1 results in no subsampling. This is applied in all dimensions.
* The default is 2. */
itkSetMacro(DisplacementFieldSubsamplingFactor, unsigned int);
itkGetMacro(DisplacementFieldSubsamplingFactor, unsigned int);

virtual DecoratedOutputTransformType *
GetOutput(DataObjectPointerArraySizeType i);
virtual const DecoratedOutputTransformType *
Expand Down Expand Up @@ -286,6 +293,10 @@ class ANTSRegistration : public ProcessObject
DataObjectPointer MakeOutput(DataObjectPointerArraySizeType) override;
using RegistrationHelperType = ::ants::RegistrationHelper<TParametersValueType, FixedImageType::ImageDimension>;
using InternalImageType = typename RegistrationHelperType::ImageType; // float or double pixels
using DisplacementFieldTransformType = typename RegistrationHelperType::DisplacementFieldTransformType;
using DisplacementFieldType = typename DisplacementFieldTransformType::DisplacementFieldType;
using DisplacementFieldTransformParametersAdaptorType =
DisplacementFieldTransformParametersAdaptor<DisplacementFieldTransformType>;

template <typename TImage>
typename InternalImageType::Pointer
Expand Down Expand Up @@ -346,6 +357,7 @@ class ANTSRegistration : public ProcessObject
unsigned int m_Radius{ 4 };
bool m_CollapseCompositeTransform{ true };
bool m_MaskAllStages{ false };
unsigned int m_DisplacementFieldSubsamplingFactor{ 2 };

std::vector<unsigned int> m_SynIterations{ 40, 20, 0 };
std::vector<unsigned int> m_AffineIterations{ 2100, 1200, 1200, 10 };
Expand All @@ -355,7 +367,10 @@ class ANTSRegistration : public ProcessObject
std::vector<ParametersValueType> m_RestrictTransformation;

private:
typename RegistrationHelperType::Pointer m_Helper{ RegistrationHelperType::New() };
typename RegistrationHelperType::Pointer m_Helper{ RegistrationHelperType::New() };
typename DisplacementFieldTransformParametersAdaptorType::Pointer m_DisplacementFieldAdaptor{
DisplacementFieldTransformParametersAdaptorType::New()
};

#ifdef ITK_USE_CONCEPT_CHECKING
static_assert(TFixedImage::ImageDimension == TMovingImage::ImageDimension,
Expand Down
42 changes: 39 additions & 3 deletions include/itkANTSRegistration.hxx
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ ANTSRegistration<TFixedImage, TMovingImage, TParametersValueType>::PrintSelf(std
os << indent << "Radius: " << this->m_Radius << std::endl;
os << indent << "CollapseCompositeTransform: " << (this->m_CollapseCompositeTransform ? "On" : "Off") << std::endl;
os << indent << "MaskAllStages: " << (this->m_MaskAllStages ? "On" : "Off") << std::endl;
os << indent << "DisplacementFieldSubsamplingFactor: " << this->m_DisplacementFieldSubsamplingFactor << std::endl;

os << indent << "SynIterations: " << this->m_SynIterations << std::endl;
os << indent << "AffineIterations: " << this->m_AffineIterations << std::endl;
Expand Down Expand Up @@ -257,8 +258,8 @@ ANTSRegistration<TFixedImage, TMovingImage, TParametersValueType>::MakeOutput(Da
template <typename TFixedImage, typename TMovingImage, typename TParametersValueType>
template <typename TImage>
auto
ANTSRegistration<TFixedImage, TMovingImage, TParametersValueType>::CastImageToInternalType(
const TImage * inputImage) -> typename InternalImageType::Pointer
ANTSRegistration<TFixedImage, TMovingImage, TParametersValueType>::CastImageToInternalType(const TImage * inputImage) ->
typename InternalImageType::Pointer
{
using CastFilterType = CastImageFilter<TImage, InternalImageType>;
typename CastFilterType::Pointer castFilter = CastFilterType::New();
Expand Down Expand Up @@ -606,7 +607,7 @@ ANTSRegistration<TFixedImage, TMovingImage, TParametersValueType>::GenerateData(
{
itkExceptionMacro(<< "Unsupported transform type: " << this->GetTypeOfTransform());
}
this->UpdateProgress(0.95);
this->UpdateProgress(0.90);

typename OutputTransformType::Pointer forwardTransform = m_Helper->GetModifiableCompositeTransform();
if (m_CollapseCompositeTransform)
Expand All @@ -615,6 +616,41 @@ ANTSRegistration<TFixedImage, TMovingImage, TParametersValueType>::GenerateData(
}
this->SetForwardTransform(forwardTransform);

if (m_DisplacementFieldSubsamplingFactor > 1)
{
using TransformType = typename OutputTransformType::TransformType;
for (unsigned int i = 0; i < forwardTransform->GetNumberOfTransforms(); ++i)
{
typename TransformType::Pointer transform = forwardTransform->GetNthTransform(i);
typename DisplacementFieldTransformType::Pointer displacementFieldTransform =
dynamic_cast<DisplacementFieldTransformType *>(transform.GetPointer());
if (displacementFieldTransform)
{
// The transform is a DisplacementFieldTransform
displacementFieldTransform->Print(std::cout, 3);
const auto displacementField = displacementFieldTransform->GetDisplacementField();
m_DisplacementFieldAdaptor->SetTransform(displacementFieldTransform);
m_DisplacementFieldAdaptor->SetRequiredOrigin(displacementField->GetOrigin());
m_DisplacementFieldAdaptor->SetRequiredDirection(displacementField->GetDirection());
auto requiredSize = displacementField->GetLargestPossibleRegion().GetSize();
for (unsigned int i = 0; i < requiredSize.GetSizeDimension(); ++i)
{
requiredSize[i] /= m_DisplacementFieldSubsamplingFactor;
}
m_DisplacementFieldAdaptor->SetRequiredSize(requiredSize);
auto requiredSpacing = displacementField->GetSpacing();
for (unsigned int i = 0; i < requiredSpacing.GetVectorDimension(); ++i)
{
requiredSpacing[i] *= m_DisplacementFieldSubsamplingFactor;
}
m_DisplacementFieldAdaptor->SetRequiredSpacing(requiredSpacing);
m_DisplacementFieldAdaptor->AdaptTransformParameters();
displacementFieldTransform->Print(std::cout, 3);
}
}
}
this->UpdateProgress(0.95);

typename OutputTransformType::Pointer inverseTransform = OutputTransformType::New();
if (forwardTransform->GetInverse(inverseTransform))
{
Expand Down

0 comments on commit 966cec0

Please sign in to comment.