diff --git a/src/Geometry/Geometry.jl b/src/Geometry/Geometry.jl index ea226683bb..4f676692c4 100644 --- a/src/Geometry/Geometry.jl +++ b/src/Geometry/Geometry.jl @@ -23,6 +23,65 @@ export Contravariant1Vector, Contravariant123Vector +""" + SimpleSymmetric + +A simple Symmetric matrix. `LinearAlgebra.Symmetric` has a field +`uplo::Char`, which results in a failed `check_basetype(T, S)`. +""" +struct SimpleSymmetric{T, S <: AbstractMatrix{<:T}} <: AbstractMatrix{T} + data::S + function SimpleSymmetric{T, S}(data) where {T, S <: AbstractMatrix{<:T}} + new{T, S}(data) + end +end +SimpleSymmetric(A) = cc_symmetric_type(typeof(A))(A) + +LinearAlgebra.transpose(A::SimpleSymmetric) = A + +""" + cc_symmetric_type(T::Type) + +The type of the object returned by `symmetric(::T, ::Symbol)`. For matrices, this is an +appropriately typed `Symmetric`, for `Number`s, it is the original type. If `symmetric` is +implemented for a custom type, so should be `cc_symmetric_type`, and vice versa. +""" +cc_symmetric_type(::Type{T}) where {S, T <: AbstractMatrix{S}} = + SimpleSymmetric{ + Union{S, Base.promote_op(transpose, S), cc_symmetric_type(S)}, + T, + } +cc_symmetric_type(::Type{T}) where {S <: Number, T <: AbstractMatrix{S}} = + SimpleSymmetric{S, T} +cc_symmetric_type( + ::Type{T}, +) where {S <: AbstractMatrix, T <: AbstractMatrix{S}} = + SimpleSymmetric{AbstractMatrix, T} +cc_symmetric_type(::Type{T}) where {T <: Number} = T + +Base.@propagate_inbounds function Base.getindex( + A::SimpleSymmetric, + i::Integer, + j::Integer, +) + @boundscheck checkbounds(A, i, j) + @inbounds if i == j + data_ij = A.data[i, j] + if data_ij isa AbstractMatrix + return SimpleSymmetric(data_ij)::cc_symmetric_type(eltype(A.data)) + elseif data_ij isa Number + return data_ij::cc_symmetric_type(eltype(A.data)) + end + elseif (i < j) + return A.data[i, j] + else + return LinearAlgebra.transpose(A.data[j, i]) + end +end + +Base.@propagate_inbounds Base.getindex(A::SimpleSymmetric, i::Integer) = + A.data[i] + include("coordinates.jl") include("axistensors.jl") diff --git a/src/Geometry/axistensors.jl b/src/Geometry/axistensors.jl index 534e628380..8eb9134839 100644 --- a/src/Geometry/axistensors.jl +++ b/src/Geometry/axistensors.jl @@ -136,7 +136,10 @@ struct AxisTensor{ T, N, A <: NTuple{N, AbstractAxis}, - S <: StaticArray{<:Tuple, T, N}, + S <: Union{ + SimpleSymmetric{T, <:StaticArray{<:Tuple, T, N}}, + StaticArray{<:Tuple, T, N}, + }, } <: AbstractArray{T, N} axes::A components::S @@ -147,7 +150,10 @@ AxisTensor( components::S, ) where { A <: Tuple{Vararg{AbstractAxis}}, - S <: StaticArray{<:Tuple, T, N}, + S <: Union{ + SimpleSymmetric{T, <:StaticArray{<:Tuple, T, N}}, + StaticArray{<:Tuple, T, N}, + }, } where {T, N} = AxisTensor{T, N, A, S}(axes, components) AxisTensor(axes::Tuple{Vararg{AbstractAxis}}, components) = diff --git a/src/Geometry/localgeometry.jl b/src/Geometry/localgeometry.jl index a63efe81c9..94a644ee34 100644 --- a/src/Geometry/localgeometry.jl +++ b/src/Geometry/localgeometry.jl @@ -1,3 +1,7 @@ +import LinearAlgebra +import .Geometry: SimpleSymmetric + +import InteractiveUtils """ LocalGeometry @@ -20,7 +24,11 @@ struct LocalGeometry{I, C <: AbstractPoint, FT, S} "Contravariant metric tensor (inverse of gᵢⱼ), transforms covariant to contravariant vector components" gⁱʲ::Axis2Tensor{FT, Tuple{ContravariantAxis{I}, ContravariantAxis{I}}, S} "Covariant metric tensor (gᵢⱼ), transforms contravariant to covariant vector components" - gᵢⱼ::Axis2Tensor{FT, Tuple{CovariantAxis{I}, CovariantAxis{I}}, S} + gᵢⱼ::Axis2Tensor{ + FT, + Tuple{CovariantAxis{I}, CovariantAxis{I}}, + SimpleSymmetric{FT, S}, + } @inline function LocalGeometry( coordinates, J, @@ -30,20 +38,19 @@ struct LocalGeometry{I, C <: AbstractPoint, FT, S} ∂ξ∂x = inv(∂x∂ξ) C = typeof(coordinates) Jinv = inv(J) - return new{I, C, FT, S}( - coordinates, - J, - WJ, - Jinv, - ∂x∂ξ, - ∂ξ∂x, - ∂ξ∂x * ∂ξ∂x', - ∂x∂ξ' * ∂x∂ξ, + gᵢⱼ₀ = ∂x∂ξ' * ∂x∂ξ + gⁱʲ = ∂ξ∂x * ∂ξ∂x' + @assert LinearAlgebra.issymmetric(components(gᵢⱼ₀)) + @assert LinearAlgebra.issymmetric(components(gⁱʲ)) + ˢgᵢⱼ = SimpleSymmetric(components(gᵢⱼ₀)) + gᵢⱼ = AxisTensor{FT, 2, typeof(axes(gᵢⱼ₀)), typeof(ˢgᵢⱼ)}( + axes(gᵢⱼ₀), + ˢgᵢⱼ, ) + return new{I, C, FT, S}(coordinates, J, WJ, Jinv, ∂x∂ξ, ∂ξ∂x, gⁱʲ, gᵢⱼ) end end - """ SurfaceGeometry diff --git a/test/Fields/unit_field.jl b/test/Fields/unit_field.jl index edb93f87e1..896eb30f01 100644 --- a/test/Fields/unit_field.jl +++ b/test/Fields/unit_field.jl @@ -708,16 +708,15 @@ end Geometry.Cartesian123Point(x1, x2, x3), ] all_components = [ - SMatrix{1, 1, FT}(range(1, 1)...), - SMatrix{2, 2, FT}(range(1, 4)...), - SMatrix{3, 3, FT}(range(1, 9)...), - SMatrix{3, 3, FT}(range(1, 9)...), - SMatrix{1, 1, FT}(range(1, 1)...), - SMatrix{2, 2, FT}(range(1, 4)...), - SMatrix{3, 3, FT}(range(1, 9)...), + SMatrix{1, 1}(FT[1]), + SMatrix{2, 2}(FT[1 2; 3 4]), + SMatrix{3, 3}(FT[1 2 10; 4 5 6; 7 8 9]), + SMatrix{3, 3}(FT[1 2 10; 4 5 6; 7 8 9]), + SMatrix{2, 2}(FT[1 2; 3 4]), + SMatrix{3, 3}(FT[1 2 10; 4 5 6; 7 8 9]), ] - expected_dzs = [1.0, 4.0, 9.0, 9.0, 1.0, 4.0, 9.0] + expected_dzs = [1.0, 4.0, 9.0, 9.0, 1.0, 2.0, 9.0] for (components, coord, expected_dz) in zip(all_components, coords, expected_dzs) diff --git a/test/Geometry/axistensor_conversion_benchmarks.jl b/test/Geometry/axistensor_conversion_benchmarks.jl index 9188561363..81abbb61e5 100644 --- a/test/Geometry/axistensor_conversion_benchmarks.jl +++ b/test/Geometry/axistensor_conversion_benchmarks.jl @@ -1,3 +1,7 @@ +#= +julia --project +using Revise; include(joinpath("test", "Geometry", "axistensor_conversion_benchmarks.jl")) +=# using Test, StaticArrays #! format: off import Random, BenchmarkTools, StatsBase,