Skip to content

Commit

Permalink
Use Symmetric Covariant Axis2Tensor
Browse files Browse the repository at this point in the history
wip
  • Loading branch information
charleskawczynski committed Jun 18, 2024
1 parent 161d952 commit af06f67
Show file tree
Hide file tree
Showing 5 changed files with 96 additions and 21 deletions.
59 changes: 59 additions & 0 deletions src/Geometry/Geometry.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,65 @@ export Contravariant1Vector,
Contravariant123Vector


"""
SimpleSymmetric
A simple Symmetric matrix. `LinearAlgebra.Symmetric` has a field
`uplo::Char`, which results in a failed `check_basetype(T, S)`.
"""
struct SimpleSymmetric{T, S <: AbstractMatrix{<:T}} <: AbstractMatrix{T}
data::S
function SimpleSymmetric{T, S}(data) where {T, S <: AbstractMatrix{<:T}}
new{T, S}(data)
end
end
SimpleSymmetric(A) = cc_symmetric_type(typeof(A))(A)

LinearAlgebra.transpose(A::SimpleSymmetric) = A

"""
cc_symmetric_type(T::Type)
The type of the object returned by `symmetric(::T, ::Symbol)`. For matrices, this is an
appropriately typed `Symmetric`, for `Number`s, it is the original type. If `symmetric` is
implemented for a custom type, so should be `cc_symmetric_type`, and vice versa.
"""
cc_symmetric_type(::Type{T}) where {S, T <: AbstractMatrix{S}} =
SimpleSymmetric{
Union{S, Base.promote_op(transpose, S), cc_symmetric_type(S)},
T,
}
cc_symmetric_type(::Type{T}) where {S <: Number, T <: AbstractMatrix{S}} =
SimpleSymmetric{S, T}
cc_symmetric_type(
::Type{T},
) where {S <: AbstractMatrix, T <: AbstractMatrix{S}} =
SimpleSymmetric{AbstractMatrix, T}
cc_symmetric_type(::Type{T}) where {T <: Number} = T

Base.@propagate_inbounds function Base.getindex(
A::SimpleSymmetric,
i::Integer,
j::Integer,
)
@boundscheck checkbounds(A, i, j)
@inbounds if i == j
data_ij = A.data[i, j]
if data_ij isa AbstractMatrix
return SimpleSymmetric(data_ij)::cc_symmetric_type(eltype(A.data))
elseif data_ij isa Number
return data_ij::cc_symmetric_type(eltype(A.data))
end
elseif (i < j)
return A.data[i, j]
else
return LinearAlgebra.transpose(A.data[j, i])
end
end

Base.@propagate_inbounds Base.getindex(A::SimpleSymmetric, i::Integer) =
A.data[i]


include("coordinates.jl")
include("axistensors.jl")
Expand Down
10 changes: 8 additions & 2 deletions src/Geometry/axistensors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,10 @@ struct AxisTensor{
T,
N,
A <: NTuple{N, AbstractAxis},
S <: StaticArray{<:Tuple, T, N},
S <: Union{
SimpleSymmetric{T, <:StaticArray{<:Tuple, T, N}},
StaticArray{<:Tuple, T, N},
},
} <: AbstractArray{T, N}
axes::A
components::S
Expand All @@ -147,7 +150,10 @@ AxisTensor(
components::S,
) where {
A <: Tuple{Vararg{AbstractAxis}},
S <: StaticArray{<:Tuple, T, N},
S <: Union{
SimpleSymmetric{T, <:StaticArray{<:Tuple, T, N}},
StaticArray{<:Tuple, T, N},
},
} where {T, N} = AxisTensor{T, N, A, S}(axes, components)

AxisTensor(axes::Tuple{Vararg{AbstractAxis}}, components) =
Expand Down
29 changes: 18 additions & 11 deletions src/Geometry/localgeometry.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
import LinearAlgebra
import .Geometry: SimpleSymmetric

import InteractiveUtils

"""
LocalGeometry
Expand All @@ -20,7 +24,11 @@ struct LocalGeometry{I, C <: AbstractPoint, FT, S}
"Contravariant metric tensor (inverse of gᵢⱼ), transforms covariant to contravariant vector components"
gⁱʲ::Axis2Tensor{FT, Tuple{ContravariantAxis{I}, ContravariantAxis{I}}, S}
"Covariant metric tensor (gᵢⱼ), transforms contravariant to covariant vector components"
gᵢⱼ::Axis2Tensor{FT, Tuple{CovariantAxis{I}, CovariantAxis{I}}, S}
gᵢⱼ::Axis2Tensor{
FT,
Tuple{CovariantAxis{I}, CovariantAxis{I}},
SimpleSymmetric{FT, S},
}
@inline function LocalGeometry(
coordinates,
J,
Expand All @@ -30,20 +38,19 @@ struct LocalGeometry{I, C <: AbstractPoint, FT, S}
∂ξ∂x = inv(∂x∂ξ)
C = typeof(coordinates)
Jinv = inv(J)
return new{I, C, FT, S}(
coordinates,
J,
WJ,
Jinv,
∂x∂ξ,
∂ξ∂x,
∂ξ∂x * ∂ξ∂x',
∂x∂ξ' * ∂x∂ξ,
gᵢⱼ₀ = ∂x∂ξ' * ∂x∂ξ
gⁱʲ = ∂ξ∂x * ∂ξ∂x'
@assert LinearAlgebra.issymmetric(components(gᵢⱼ₀))
@assert LinearAlgebra.issymmetric(components(gⁱʲ))
ˢgᵢⱼ = SimpleSymmetric(components(gᵢⱼ₀))
gᵢⱼ = AxisTensor{FT, 2, typeof(axes(gᵢⱼ₀)), typeof(ˢgᵢⱼ)}(
axes(gᵢⱼ₀),
ˢgᵢⱼ,
)
return new{I, C, FT, S}(coordinates, J, WJ, Jinv, ∂x∂ξ, ∂ξ∂x, gⁱʲ, gᵢⱼ)
end
end


"""
SurfaceGeometry
Expand Down
15 changes: 7 additions & 8 deletions test/Fields/unit_field.jl
Original file line number Diff line number Diff line change
Expand Up @@ -708,16 +708,15 @@ end
Geometry.Cartesian123Point(x1, x2, x3),
]
all_components = [
SMatrix{1, 1, FT}(range(1, 1)...),
SMatrix{2, 2, FT}(range(1, 4)...),
SMatrix{3, 3, FT}(range(1, 9)...),
SMatrix{3, 3, FT}(range(1, 9)...),
SMatrix{1, 1, FT}(range(1, 1)...),
SMatrix{2, 2, FT}(range(1, 4)...),
SMatrix{3, 3, FT}(range(1, 9)...),
SMatrix{1, 1}(FT[1]),
SMatrix{2, 2}(FT[1 2; 3 4]),
SMatrix{3, 3}(FT[1 2 10; 4 5 6; 7 8 9]),
SMatrix{3, 3}(FT[1 2 10; 4 5 6; 7 8 9]),
SMatrix{2, 2}(FT[1 2; 3 4]),
SMatrix{3, 3}(FT[1 2 10; 4 5 6; 7 8 9]),
]

expected_dzs = [1.0, 4.0, 9.0, 9.0, 1.0, 4.0, 9.0]
expected_dzs = [1.0, 4.0, 9.0, 9.0, 1.0, 2.0, 9.0]

for (components, coord, expected_dz) in
zip(all_components, coords, expected_dzs)
Expand Down
4 changes: 4 additions & 0 deletions test/Geometry/axistensor_conversion_benchmarks.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
#=
julia --project
using Revise; include(joinpath("test", "Geometry", "axistensor_conversion_benchmarks.jl"))
=#
using Test, StaticArrays
#! format: off
import Random, BenchmarkTools, StatsBase,
Expand Down

0 comments on commit af06f67

Please sign in to comment.