From 64969bd059d261dd5add4b87fdb0a65664224a77 Mon Sep 17 00:00:00 2001 From: Milan Date: Fri, 2 Sep 2022 13:49:25 +0100 Subject: [PATCH] plan_(b)rfft types as in FFTW --- src/fft.jl | 12 ++++++------ test/fft_tests.jl | 7 +++++++ 2 files changed, 13 insertions(+), 6 deletions(-) diff --git a/src/fft.jl b/src/fft.jl index 9907e1a..0abd5c6 100644 --- a/src/fft.jl +++ b/src/fft.jl @@ -286,11 +286,11 @@ end # This is the reason for using StridedArray below. We also have to carefully # distinguish between real and complex arguments. -plan_fft(x::StridedArray{T}, region) where {T <: ComplexFloats} = DummyFFTPlan{Complex{real(T)},false,typeof(region)}(region) -plan_fft!(x::StridedArray{T}, region) where {T <: ComplexFloats} = DummyFFTPlan{Complex{real(T)},true,typeof(region)}(region) +plan_fft(x::StridedArray{T}, region) where {T <: ComplexFloats} = DummyFFTPlan{T,false,typeof(region)}(region) +plan_fft!(x::StridedArray{T}, region) where {T <: ComplexFloats} = DummyFFTPlan{T,true,typeof(region)}(region) -plan_bfft(x::StridedArray{T}, region) where {T <: ComplexFloats} = DummybFFTPlan{Complex{real(T)},false,typeof(region)}(region) -plan_bfft!(x::StridedArray{T}, region) where {T <: ComplexFloats} = DummybFFTPlan{Complex{real(T)},true,typeof(region)}(region) +plan_bfft(x::StridedArray{T}, region) where {T <: ComplexFloats} = DummybFFTPlan{T,false,typeof(region)}(region) +plan_bfft!(x::StridedArray{T}, region) where {T <: ComplexFloats} = DummybFFTPlan{T,true,typeof(region)}(region) # The ifft plans are automatically provided in terms of the bfft plans above. # plan_ifft(x::StridedArray{T}, region) where {T <: ComplexFloats} = DummyiFFTPlan{Complex{real(T)},false,typeof(region)}(region) @@ -302,8 +302,8 @@ plan_dct!(x::StridedArray{T}, region) where {T <: AbstractFloats} = DummyDCTPlan plan_idct(x::StridedArray{T}, region) where {T <: AbstractFloats} = DummyiDCTPlan{T,false,typeof(region)}(region) plan_idct!(x::StridedArray{T}, region) where {T <: AbstractFloats} = DummyiDCTPlan{T,true,typeof(region)}(region) -plan_rfft(x::StridedArray{T}, region) where {T <: RealFloats} = DummyrFFTPlan{Complex{real(T)},false,typeof(region)}(length(x), region) -plan_brfft(x::StridedArray{T}, n::Integer, region) where {T <: ComplexFloats} = DummybrFFTPlan{Complex{real(T)},false,typeof(region)}(n, region) +plan_rfft(x::StridedArray{T}, region) where {T <: RealFloats} = DummyrFFTPlan{T,false,typeof(region)}(length(x), region) +plan_brfft(x::StridedArray{T}, n::Integer, region) where {T <: ComplexFloats} = DummybrFFTPlan{T,false,typeof(region)}(n, region) # A plan for irfft is created in terms of a plan for brfft. # plan_irfft(x::StridedArray{T}, n::Integer, region) where {T <: ComplexFloats} = DummyirFFTPlan{Complex{real(T)},false,typeof(region)}(n, region) diff --git a/test/fft_tests.jl b/test/fft_tests.jl index c44414d..556067a 100644 --- a/test/fft_tests.jl +++ b/test/fft_tests.jl @@ -118,6 +118,13 @@ function test_fftw() @test !( plan_rfft(rand(T,10), 1:1) isa GenericFFT.DummyPlan ) @test !( plan_brfft(rand(Complex{T},10), 19) isa GenericFFT.DummyPlan ) @test !( plan_brfft(rand(Complex{T},10), 19, 1:1) isa GenericFFT.DummyPlan ) + + # check that GenericFFT and FFTW plans have the same parametric type + @test plan_rfft(rand(Float16,10)) isa AbstractFFTs.Plan{Float16} + @test plan_rfft(rand(Float64,10)) isa AbstractFFTs.Plan{Float64} + + @test plan_brfft(rand(Complex{Float16},10),19) isa AbstractFFTs.Plan{Complex{Float16}} + @test plan_brfft(rand(Complex{Float64},10),19) isa AbstractFFTs.Plan{Complex{Float64}} end end