Skip to content

Commit

Permalink
ENH: default projectors for all iterative filters
Browse files Browse the repository at this point in the history
  • Loading branch information
acoussat committed Apr 14, 2020
1 parent bc6d102 commit 2c8ae75
Show file tree
Hide file tree
Showing 26 changed files with 119 additions and 377 deletions.
9 changes: 0 additions & 9 deletions include/rtkADMMTotalVariationConeBeamReconstructionFilter.h
Expand Up @@ -174,15 +174,6 @@ class ADMMTotalVariationConeBeamReconstructionFilter
using DisplacedDetectorFilterType = rtk::DisplacedDetectorImageFilter<TOutputImage>;
using GatingWeightsFilterType = rtk::MultiplyByVectorImageFilter<TOutputImage>;

/** Pass the ForwardProjection filter to the conjugate gradient operator */
void
SetForwardProjectionFilter(ForwardProjectionType _arg) override;

/** Pass the backprojection filter to the conjugate gradient operator and to the back projection filter generating the
* B of AX=B */
void
SetBackProjectionFilter(BackProjectionType _arg) override;

/** Pass the geometry to all filters needing it */
itkSetObjectMacro(Geometry, ThreeDCircularProjectionGeometry);

Expand Down
40 changes: 13 additions & 27 deletions include/rtkADMMTotalVariationConeBeamReconstructionFilter.hxx
Expand Up @@ -95,33 +95,6 @@ ADMMTotalVariationConeBeamReconstructionFilter<TOutputImage,
m_DisplacedDetectorFilter->ReleaseDataFlagOn();
}

template <typename TOutputImage, typename TGradientOutputImage>
void
ADMMTotalVariationConeBeamReconstructionFilter<TOutputImage, TGradientOutputImage>::SetForwardProjectionFilter(
ForwardProjectionType _arg)
{
if (_arg != this->GetForwardProjectionFilter())
{
Superclass::SetForwardProjectionFilter(_arg);
m_ForwardProjectionFilter = this->InstantiateForwardProjectionFilter(_arg);
m_CGOperator->SetForwardProjectionFilter(m_ForwardProjectionFilter);
}
}

template <typename TOutputImage, typename TGradientOutputImage>
void
ADMMTotalVariationConeBeamReconstructionFilter<TOutputImage, TGradientOutputImage>::SetBackProjectionFilter(
BackProjectionType _arg)
{
if (_arg != this->GetBackProjectionFilter())
{
Superclass::SetBackProjectionFilter(_arg);
m_BackProjectionFilter = this->InstantiateBackProjectionFilter(_arg);
m_BackProjectionFilterForConjugateGradient = this->InstantiateBackProjectionFilter(_arg);
m_CGOperator->SetBackProjectionFilter(m_BackProjectionFilterForConjugateGradient);
}
}

template <typename TOutputImage, typename TGradientOutputImage>
void
ADMMTotalVariationConeBeamReconstructionFilter<TOutputImage, TGradientOutputImage>::SetBetaForCurrentIteration(int iter)
Expand Down Expand Up @@ -177,6 +150,19 @@ template <typename TOutputImage, typename TGradientOutputImage>
void
ADMMTotalVariationConeBeamReconstructionFilter<TOutputImage, TGradientOutputImage>::GenerateOutputInformation()
{
// Set forward projection filter
m_ForwardProjectionFilter = this->InstantiateForwardProjectionFilter(this->m_CurrentForwardProjectionConfiguration);
// Pass the ForwardProjection filter to the conjugate gradient operator
m_CGOperator->SetForwardProjectionFilter(m_ForwardProjectionFilter);

// Set back projection filter
m_BackProjectionFilter = this->InstantiateBackProjectionFilter(this->m_CurrentBackProjectionConfiguration);
// Pass the backprojection filter to the conjugate gradient operator and to the back projection filter generating the
// B of AX=B
m_BackProjectionFilterForConjugateGradient =
this->InstantiateBackProjectionFilter(this->m_CurrentBackProjectionConfiguration);
m_CGOperator->SetBackProjectionFilter(m_BackProjectionFilterForConjugateGradient);

// Set runtime connections
m_GradientFilter1->SetInput(this->GetInput(0));
m_ZeroMultiplyVolumeFilter->SetInput1(this->GetInput(0));
Expand Down
9 changes: 0 additions & 9 deletions include/rtkADMMWaveletsConeBeamReconstructionFilter.h
Expand Up @@ -176,15 +176,6 @@ class ADMMWaveletsConeBeamReconstructionFilter
using ForwardProjectionType = typename Superclass::ForwardProjectionType;
using BackProjectionType = typename Superclass::BackProjectionType;

/** Pass the ForwardProjection filter to the conjugate gradient operator */
void
SetForwardProjectionFilter(ForwardProjectionType _arg) override;

/** Pass the backprojection filter to the conjugate gradient operator and to the back projection filter generating the
* B of AX=B */
void
SetBackProjectionFilter(BackProjectionType _arg) override;

/** Pass the geometry to all filters needing it */
itkSetObjectMacro(Geometry, ThreeDCircularProjectionGeometry);

Expand Down
39 changes: 14 additions & 25 deletions include/rtkADMMWaveletsConeBeamReconstructionFilter.hxx
Expand Up @@ -72,31 +72,6 @@ ADMMWaveletsConeBeamReconstructionFilter<TOutputImage>::ADMMWaveletsConeBeamReco
m_DisplacedDetectorFilter->ReleaseDataFlagOn();
}

template <typename TOutputImage>
void
ADMMWaveletsConeBeamReconstructionFilter<TOutputImage>::SetForwardProjectionFilter(ForwardProjectionType _arg)
{
if (_arg != this->GetForwardProjectionFilter())
{
Superclass::SetForwardProjectionFilter(_arg);
m_ForwardProjectionFilterForConjugateGradient = this->InstantiateForwardProjectionFilter(_arg);
m_CGOperator->SetForwardProjectionFilter(m_ForwardProjectionFilterForConjugateGradient);
}
}

template <typename TOutputImage>
void
ADMMWaveletsConeBeamReconstructionFilter<TOutputImage>::SetBackProjectionFilter(BackProjectionType _arg)
{
if (_arg != this->GetBackProjectionFilter())
{
Superclass::SetBackProjectionFilter(_arg);
m_BackProjectionFilter = this->InstantiateBackProjectionFilter(_arg);
m_BackProjectionFilterForConjugateGradient = this->InstantiateBackProjectionFilter(_arg);
m_CGOperator->SetBackProjectionFilter(m_BackProjectionFilterForConjugateGradient);
}
}

template <class TOutputImage>
void
ADMMWaveletsConeBeamReconstructionFilter<TOutputImage>::VerifyPreconditions() ITKv5_CONST
Expand Down Expand Up @@ -132,6 +107,20 @@ template <typename TOutputImage>
void
ADMMWaveletsConeBeamReconstructionFilter<TOutputImage>::GenerateOutputInformation()
{
// Set forward projection filter
m_ForwardProjectionFilterForConjugateGradient =
this->InstantiateForwardProjectionFilter(this->m_CurrentForwardProjectionConfiguration);
// Pass the ForwardProjection filter to the conjugate gradient operator
m_CGOperator->SetForwardProjectionFilter(m_ForwardProjectionFilterForConjugateGradient);

// Set back projection filter
m_BackProjectionFilter = this->InstantiateBackProjectionFilter(this->m_CurrentBackProjectionConfiguration);
// Pass the backprojection filter to the conjugate gradient operator and to the back projection filter generating the
// B of AX=B
m_BackProjectionFilterForConjugateGradient =
this->InstantiateBackProjectionFilter(this->m_CurrentBackProjectionConfiguration);
m_CGOperator->SetBackProjectionFilter(m_BackProjectionFilterForConjugateGradient);

// Set runtime connections
m_ZeroMultiplyFilter->SetInput1(this->GetInput(0));
m_CGOperator->SetInput(1, this->GetInput(1)); // The projections (the conjugate gradient operator needs them)
Expand Down
9 changes: 0 additions & 9 deletions include/rtkConjugateGradientConeBeamReconstructionFilter.h
Expand Up @@ -170,15 +170,6 @@ class ConjugateGradientConeBeamReconstructionFilter : public IterativeConeBeamRe
using ConstantImageSourceType = ConstantImageSource<TOutputImage>;
#endif

/** Pass the ForwardProjection filter to the conjugate gradient operator */
void
SetForwardProjectionFilter(ForwardProjectionType _arg) override;

/** Pass the backprojection filter to the conjugate gradient operator and to the back projection filter generating the
* B of AX=B */
void
SetBackProjectionFilter(BackProjectionType _arg) override;

/** Set the support mask, if any, for support constraint in reconstruction */
void
SetSupportMask(const TSingleComponentImage * SupportMask);
Expand Down
39 changes: 12 additions & 27 deletions include/rtkConjugateGradientConeBeamReconstructionFilter.hxx
Expand Up @@ -123,33 +123,6 @@ ConjugateGradientConeBeamReconstructionFilter<TOutputImage, TSingleComponentImag
return m_ConjugateGradientFilter->GetResidualCosts();
}

template <typename TOutputImage, typename TSingleComponentImage, typename TWeightsImage>
void
ConjugateGradientConeBeamReconstructionFilter<TOutputImage, TSingleComponentImage, TWeightsImage>::
SetForwardProjectionFilter(ForwardProjectionType _arg)
{
if (_arg != this->GetForwardProjectionFilter())
{
Superclass::SetForwardProjectionFilter(_arg);
m_ForwardProjectionFilter = this->InstantiateForwardProjectionFilter(_arg);
m_CGOperator->SetForwardProjectionFilter(m_ForwardProjectionFilter);
}
}

template <typename TOutputImage, typename TSingleComponentImage, typename TWeightsImage>
void
ConjugateGradientConeBeamReconstructionFilter<TOutputImage, TSingleComponentImage, TWeightsImage>::
SetBackProjectionFilter(BackProjectionType _arg)
{
if (_arg != this->GetBackProjectionFilter())
{
Superclass::SetBackProjectionFilter(_arg);
m_BackProjectionFilter = this->InstantiateBackProjectionFilter(_arg);
m_BackProjectionFilterForB = this->InstantiateBackProjectionFilter(_arg);
m_CGOperator->SetBackProjectionFilter(m_BackProjectionFilter);
}
}

template <typename TOutputImage, typename TSingleComponentImage, typename TWeightsImage>
void
ConjugateGradientConeBeamReconstructionFilter<TOutputImage, TSingleComponentImage, TWeightsImage>::VerifyPreconditions()
Expand Down Expand Up @@ -209,6 +182,18 @@ ConjugateGradientConeBeamReconstructionFilter<TOutputImage, TSingleComponentImag
m_ConjugateGradientFilter->SetA(m_CGOperator.GetPointer());
m_ConjugateGradientFilter->SetIterationCosts(m_IterationCosts);

// Set forward projection filter
m_ForwardProjectionFilter = this->InstantiateForwardProjectionFilter(this->m_CurrentForwardProjectionConfiguration);
// Pass the ForwardProjection filter to the conjugate gradient operator
m_CGOperator->SetForwardProjectionFilter(m_ForwardProjectionFilter);

// Set back projection filter
m_BackProjectionFilter = this->InstantiateBackProjectionFilter(this->m_CurrentBackProjectionConfiguration);
// Pass the backprojection filter to the conjugate gradient operator and to the back projection filter generating the
// B of AX=B
m_BackProjectionFilterForB = this->InstantiateBackProjectionFilter(this->m_CurrentBackProjectionConfiguration);
m_CGOperator->SetBackProjectionFilter(m_BackProjectionFilter);

// Set runtime connections
m_ConstantVolumeSource->SetInformationFromImage(this->GetInputVolume());
m_CGOperator->SetInputProjectionStack(this->GetInputProjectionStack());
Expand Down
Expand Up @@ -163,14 +163,6 @@ class ITK_EXPORT FourDConjugateGradientConeBeamReconstructionFilter
typename ProjectionStackType::ConstPointer
GetInputProjectionStack();

/** Pass the ForwardProjection filter to the conjugate gradient operator */
void
SetForwardProjectionFilter(ForwardProjectionType _arg) override;

/** Pass the backprojection filter to the conjugate gradient operator and to the filter generating the B of AX=B */
void
SetBackProjectionFilter(BackProjectionType _arg) override;

/** Pass the interpolation weights to subfilters */
void
SetWeights(const itk::Array2D<float> _arg);
Expand Down
73 changes: 29 additions & 44 deletions include/rtkFourDConjugateGradientConeBeamReconstructionFilter.hxx
Expand Up @@ -87,50 +87,6 @@ FourDConjugateGradientConeBeamReconstructionFilter<VolumeSeriesType, ProjectionS
return static_cast<const ProjectionStackType *>(this->itk::ProcessObject::GetInput(1));
}

template <class VolumeSeriesType, class ProjectionStackType>
void
FourDConjugateGradientConeBeamReconstructionFilter<VolumeSeriesType, ProjectionStackType>::SetForwardProjectionFilter(
ForwardProjectionType _arg)
{
if (_arg != this->GetForwardProjectionFilter())
{
Superclass::SetForwardProjectionFilter(_arg);
m_ForwardProjectionFilter = this->InstantiateForwardProjectionFilter(_arg);
m_CGOperator->SetForwardProjectionFilter(m_ForwardProjectionFilter);
}
if (_arg == 2) // The forward projection filter runs on GPU. It is most efficient to also run the interpolation on
// GPU, and to use GPU constant image sources
{
m_CGOperator->SetUseCudaInterpolation(true);
m_CGOperator->SetUseCudaSources(true);
}
}


template <class VolumeSeriesType, class ProjectionStackType>
void
FourDConjugateGradientConeBeamReconstructionFilter<VolumeSeriesType, ProjectionStackType>::SetBackProjectionFilter(
BackProjectionType _arg)
{
if (_arg != this->GetBackProjectionFilter())
{
Superclass::SetBackProjectionFilter(_arg);
m_BackProjectionFilter = this->InstantiateBackProjectionFilter(_arg);
m_CGOperator->SetBackProjectionFilter(m_BackProjectionFilter);

m_BackProjectionFilterForB = this->InstantiateBackProjectionFilter(_arg);
m_ProjStackToFourDFilter->SetBackProjectionFilter(m_BackProjectionFilterForB);
}
if (_arg == 2) // The back projection filter runs on GPU. It is most efficient to also run the splat on GPU, and to
// use GPU constant image sources
{
m_CGOperator->SetUseCudaSplat(true);
m_CGOperator->SetUseCudaSources(true);
m_ProjStackToFourDFilter->SetUseCudaSplat(true);
m_ProjStackToFourDFilter->SetUseCudaSources(true);
}
}

template <class VolumeSeriesType, class ProjectionStackType>
void
FourDConjugateGradientConeBeamReconstructionFilter<VolumeSeriesType, ProjectionStackType>::SetWeights(
Expand Down Expand Up @@ -198,6 +154,35 @@ FourDConjugateGradientConeBeamReconstructionFilter<VolumeSeriesType, ProjectionS
m_DisplacedDetectorFilter->SetDisable(m_DisableDisplacedDetectorFilter);
m_CGOperator->SetDisableDisplacedDetectorFilter(m_DisableDisplacedDetectorFilter);

// Set forward projection filter
m_ForwardProjectionFilter = this->InstantiateForwardProjectionFilter(this->m_CurrentForwardProjectionConfiguration);
// Pass the ForwardProjection filter to the conjugate gradient operator
m_CGOperator->SetForwardProjectionFilter(m_ForwardProjectionFilter);
if (this->m_CurrentForwardProjectionConfiguration ==
ForwardProjectionType::FP_CUDARAYCAST) // The forward projection filter runs on GPU. It is most efficient to also
// run the interpolation on GPU, and to use GPU constant image sources
{
m_CGOperator->SetUseCudaInterpolation(true);
m_CGOperator->SetUseCudaSources(true);
}

// Set back projection filter
m_BackProjectionFilter = this->InstantiateBackProjectionFilter(this->m_CurrentBackProjectionConfiguration);
m_CGOperator->SetBackProjectionFilter(m_BackProjectionFilter);

m_BackProjectionFilterForB = this->InstantiateBackProjectionFilter(this->m_CurrentBackProjectionConfiguration);
// Pass the backprojection filter to the conjugate gradient operator and to the filter generating the B of AX=B
m_ProjStackToFourDFilter->SetBackProjectionFilter(m_BackProjectionFilterForB);
if (this->m_CurrentBackProjectionConfiguration ==
BackProjectionType::BP_CUDAVOXELBASED) // The back projection filter runs on GPU. It is most efficient to also run
// the splat on GPU, and to use GPU constant image sources
{
m_CGOperator->SetUseCudaSplat(true);
m_CGOperator->SetUseCudaSources(true);
m_ProjStackToFourDFilter->SetUseCudaSplat(true);
m_ProjStackToFourDFilter->SetUseCudaSources(true);
}

// Have the last filter calculate its output information
m_ConjugateGradientFilter->UpdateOutputInformation();

Expand Down
8 changes: 0 additions & 8 deletions include/rtkFourDROOSTERConeBeamReconstructionFilter.h
Expand Up @@ -310,14 +310,6 @@ class FourDROOSTERConeBeamReconstructionFilter
using ForwardProjectionType = typename Superclass::ForwardProjectionType;
using BackProjectionType = typename Superclass::BackProjectionType;

/** Pass the ForwardProjection filter to SingleProjectionToFourDFilter */
void
SetForwardProjectionFilter(ForwardProjectionType _arg) override;

/** Pass the backprojection filter to ProjectionStackToFourD*/
void
SetBackProjectionFilter(BackProjectionType _arg) override;

/** Pass the interpolation weights to SingleProjectionToFourDFilter */
virtual void
SetWeights(const itk::Array2D<float> _arg);
Expand Down
28 changes: 4 additions & 24 deletions include/rtkFourDROOSTERConeBeamReconstructionFilter.hxx
Expand Up @@ -158,30 +158,6 @@ FourDROOSTERConeBeamReconstructionFilter<VolumeSeriesType, ProjectionStackType>:
return static_cast<DVFSequenceImageType *>(this->itk::ProcessObject::GetInput("InverseDisplacementField"));
}

template <typename VolumeSeriesType, typename ProjectionStackType>
void
FourDROOSTERConeBeamReconstructionFilter<VolumeSeriesType, ProjectionStackType>::SetForwardProjectionFilter(
ForwardProjectionType _arg)
{
if (_arg != this->GetForwardProjectionFilter())
{
Superclass::SetForwardProjectionFilter(_arg);
m_FourDCGFilter->SetForwardProjectionFilter(_arg);
}
}

template <typename VolumeSeriesType, typename ProjectionStackType>
void
FourDROOSTERConeBeamReconstructionFilter<VolumeSeriesType, ProjectionStackType>::SetBackProjectionFilter(
BackProjectionType _arg)
{
if (_arg != this->GetBackProjectionFilter())
{
Superclass::SetBackProjectionFilter(_arg);
m_FourDCGFilter->SetBackProjectionFilter(_arg);
}
}

template <typename VolumeSeriesType, typename ProjectionStackType>
void
FourDROOSTERConeBeamReconstructionFilter<VolumeSeriesType, ProjectionStackType>::SetWeights(
Expand Down Expand Up @@ -266,6 +242,10 @@ FourDROOSTERConeBeamReconstructionFilter<VolumeSeriesType, ProjectionStackType>:
{
const int Dimension = VolumeType::ImageDimension;

// Set projection filters
m_FourDCGFilter->SetForwardProjectionFilter(this->m_CurrentForwardProjectionConfiguration);
m_FourDCGFilter->SetBackProjectionFilter(this->m_CurrentBackProjectionConfiguration);

// The 4D conjugate gradient filter is the only part that must be in the pipeline
// whatever was the user wants
m_FourDCGFilter->SetInputVolumeSeries(this->GetInputVolumeSeries());
Expand Down
8 changes: 0 additions & 8 deletions include/rtkFourDSARTConeBeamReconstructionFilter.h
Expand Up @@ -196,14 +196,6 @@ class ITK_EXPORT FourDSARTConeBeamReconstructionFilter
itkGetMacro(EnforcePositivity, bool);
itkSetMacro(EnforcePositivity, bool);

/** Select the ForwardProjection filter */
void
SetForwardProjectionFilter(ForwardProjectionType _arg) override;

/** Select the backprojection filter */
void
SetBackProjectionFilter(BackProjectionType _arg) override;

/** Pass the interpolation weights to subfilters */
void
SetWeights(const itk::Array2D<float> _arg);
Expand Down

0 comments on commit 2c8ae75

Please sign in to comment.