Skip to content

Commit

Permalink
Define coeff function
Browse files Browse the repository at this point in the history
  • Loading branch information
bowenszhu committed Sep 16, 2024
1 parent b88fb66 commit a1d27e1
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 30 deletions.
4 changes: 2 additions & 2 deletions src/inspect.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,10 @@ function AbstractTrees.nodevalue(x::BasicSymbolic)
string(x.impl.val)
elseif isadd(x)
string(exprtype(x),
(scalar = x.impl.coeff, coeffs = Tuple(k => v for (k, v) in x.impl.dict)))
(scalar = coeff(x), coeffs = Tuple(k => v for (k, v) in x.impl.dict)))
elseif ismul(x)
string(exprtype(x),
(scalar = x.impl.coeff, powers = Tuple(k => v for (k, v) in x.impl.dict)))
(scalar = coeff(x), powers = Tuple(k => v for (k, v) in x.impl.dict)))
elseif isdiv(x) || ispow(x)
string(exprtype(x))
else
Expand Down
2 changes: 1 addition & 1 deletion src/polyform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -514,7 +514,7 @@ function quick_mulpow(x, y)
den = _Pow(symtype(y), y.impl.base, y.impl.exp-d[y.impl.base])
delete!(d, y.impl.base)
end
return _Mul(symtype(x), x.impl.coeff, d), den
return _Mul(symtype(x), coeff(x), d), den
else
return x, y
end
Expand Down
40 changes: 22 additions & 18 deletions src/types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,10 @@ function name(x::BasicSymbolic)
x.impl.name
end

function coeff(x::BasicSymbolic)
x.impl.coeff
end

# Same but different error messages
@noinline error_on_type() = error("Internal error: unreachable reached!")
@noinline error_sym() = error("Sym doesn't have a operation or arguments!")
Expand Down Expand Up @@ -293,7 +297,7 @@ function _isequal(a, b, E)
if E === SYM
nameof(a) === nameof(b)
elseif E === ADD || E === MUL
coeff_isequal(a.impl.coeff, b.impl.coeff) && isequal(a.impl.dict, b.impl.dict)
coeff_isequal(coeff(a), coeff(b)) && isequal(a.impl.dict, b.impl.dict)
elseif E === DIV
isequal(a.impl.num, b.impl.num) && isequal(a.impl.den, b.impl.den)
elseif E === POW
Expand Down Expand Up @@ -337,7 +341,7 @@ function Base.hash(s::BasicSymbolic, salt::UInt)::UInt
h = s.hash[]
!iszero(h) && return h
hashoffset = isadd(s) ? ADD_SALT : SUB_SALT
h′ = hash(hashoffset, hash(s.impl.coeff, hash(s.impl.dict, salt)))
h′ = hash(hashoffset, hash(coeff(s), hash(s.impl.dict, salt)))
s.hash[] = h′
return h′
elseif E === DIV
Expand Down Expand Up @@ -444,7 +448,7 @@ const Rat = Union{Rational, Integer}

function ratcoeff(x)
if ismul(x)
ratcoeff(x.impl.coeff)
ratcoeff(coeff(x))
elseif x isa Rat
(true, x)
else
Expand All @@ -455,7 +459,7 @@ ratio(x::Integer,y::Integer) = iszero(rem(x,y)) ? div(x,y) : x//y
ratio(x::Rat,y::Rat) = x//y
function maybe_intcoeff(x)
if ismul(x)
coeff = x.impl.coeff
coeff = coeff(x)
if coeff isa Rational && isone(denominator(coeff))
_Mul(symtype(x), coeff.num, x.impl.dict; metadata = x.metadata)
else
Expand Down Expand Up @@ -537,7 +541,7 @@ function toterm(t::BasicSymbolic{T}) where {T}
return t
elseif E === ADD || E === MUL
args = BasicSymbolic[]
push!(args, t.impl.coeff)
push!(args, coeff(t))
for (k, coeff) in t.impl.dict
push!(
args, coeff == 1 ? k : _Term(T, E === MUL ? (^) : (*), [_Const(coeff), k]))
Expand All @@ -562,7 +566,7 @@ function makeadd(sign, coeff, xs...)
d = Dict{BasicSymbolic, Any}()
for x in xs
if isadd(x)
coeff += x.impl.coeff
coeff += coeff(x)
_merge!(+, d, x.impl.dict, filter = _iszero)
continue
end
Expand All @@ -572,7 +576,7 @@ function makeadd(sign, coeff, xs...)
end
if ismul(x)
k = _Mul(symtype(x), 1, x.impl.dict)
v = sign * x.impl.coeff + get(d, k, 0)
v = sign * coeff(x) + get(d, k, 0)
else
k = x
v = sign + get(d, x, 0)
Expand All @@ -593,7 +597,7 @@ function makemul(coeff, xs...; d = Dict{BasicSymbolic, Any}())
elseif x isa Number
coeff *= x
elseif ismul(x)
coeff *= x.impl.coeff
coeff *= coeff(x)
_merge!(+, d, x.impl.dict, filter = _iszero)
else
v = 1 + get(d, x, 0)
Expand Down Expand Up @@ -1219,10 +1223,10 @@ function +(a::SN, b::SN)
!issafecanon(+, a, b) && return term(+, a, b) # Don't flatten if args have metadata
if isadd(a) && isadd(b)
return _Add(
add_t(a, b), a.impl.coeff + b.impl.coeff, _merge(+, a.impl.dict, b.impl.dict, filter = _iszero))
add_t(a, b), coeff(a) + coeff(b), _merge(+, a.impl.dict, b.impl.dict, filter = _iszero))
elseif isadd(a)
coeff, dict = makeadd(1, 0, b)
return _Add(add_t(a, b), a.impl.coeff + coeff, _merge(+, a.impl.dict, dict, filter = _iszero))
return _Add(add_t(a, b), coeff(a) + coeff, _merge(+, a.impl.dict, dict, filter = _iszero))
elseif isadd(b)
return b + a
end
Expand All @@ -1236,7 +1240,7 @@ function +(a::Number, b::SN)
!issafecanon(+, b) && return term(+, a, b) # Don't flatten if args have metadata
iszero(a) && return b
if isadd(b)
_Add(add_t(a, b), a + b.impl.coeff, b.impl.dict)
_Add(add_t(a, b), a + coeff(b), b.impl.dict)
else
_Add(add_t(a, b), makeadd(1, a, b)...)
end
Expand All @@ -1254,15 +1258,15 @@ function -(a::SN)
return term(-, a)
end
if isadd(a)
_Add(sub_t(a), -a.impl.coeff, mapvalues((_, v) -> -v, a.impl.dict))
_Add(sub_t(a), -coeff(a), mapvalues((_, v) -> -v, a.impl.dict))
else
_Add(sub_t(a), makeadd(-1, 0, a)...)
end
end
function -(a::SN, b::SN)
(!issafecanon(+, a) || !issafecanon(*, b)) && return term(-, a, b)
if isadd(a) && isadd(b)
_Add(sub_t(a, b), a.impl.coeff - b.impl.coeff, _merge(-, a.impl.dict, b.impl.dict, filter = _iszero))
_Add(sub_t(a, b), coeff(a) - coeff(b), _merge(-, a.impl.dict, b.impl.dict, filter = _iszero))
else
a + (-b)
end
Expand All @@ -1289,16 +1293,16 @@ function *(a::SN, b::SN)
elseif isdiv(b)
_Div(a * b.impl.num, b.impl.den)
elseif ismul(a) && ismul(b)
_Mul(mul_t(a, b), a.impl.coeff * b.impl.coeff,
_Mul(mul_t(a, b), coeff(a) * coeff(b),
_merge(+, a.impl.dict, b.impl.dict, filter = _iszero))
elseif ismul(a) && ispow(b)
if b.impl.exp isa Number
_Mul(mul_t(a, b),
a.impl.coeff,
coeff(a),
_merge(+, a.impl.dict, Base.ImmutableDict(b.impl.base => b.impl.exp),
filter = _iszero))
else
_Mul(mul_t(a, b), a.impl.coeff,
_Mul(mul_t(a, b), coeff(a),
_merge(+, a.impl.dict, Base.ImmutableDict(b => 1), filter = _iszero))
end
elseif ispow(a) && ismul(b)
Expand All @@ -1321,7 +1325,7 @@ function *(a::Number, b::SN)
elseif isone(-a) && isadd(b)
# -1(a+b) -> -a - b
T = promote_symtype(+, typeof(a), symtype(b))
_Add(T, b.impl.coeff * a,
_Add(T, coeff(b) * a,
Dict{BasicSymbolic, Any}(k => v * a for (k, v) in b.impl.dict))
else
_Mul(mul_t(a, b), makemul(a, b)...)
Expand All @@ -1346,7 +1350,7 @@ function ^(a::SN, b)
elseif b isa Number && b < 0
_Div(1, a^(-b))
elseif ismul(a) && b isa Number
coeff = unstable_pow(a.impl.coeff, b)
coeff = unstable_pow(coeff(a), b)
_Mul(promote_symtype(^, symtype(a), symtype(b)),
coeff, mapvalues((k, v) -> b * v, a.impl.dict))
else
Expand Down
18 changes: 9 additions & 9 deletions test/basics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -344,18 +344,18 @@ end
@testset "div" begin
@syms x::SafeReal y::Real
@test issym((2x / 2y).impl.num)
@test (2x / 3y).impl.num.impl.coeff == 2
@test (2x / 3y).impl.den.impl.coeff == 3
@test (2x / -3x).impl.num.impl.coeff == -2
@test (2x / -3x).impl.den.impl.coeff == 3
@test (2.5x / 3x).impl.num.impl.coeff == 2.5
@test (2.5x / 3x).impl.den.impl.coeff == 3
@test (x / 3x).impl.den.impl.coeff == 3
@test coeff((2x / 3y).impl.num) == 2
@test coeff((2x / 3y).impl.den) == 3
@test coeff((2x / -3x).impl.num) == -2
@test coeff((2x / -3x).impl.den) == 3
@test coeff((2.5x / 3x).impl.num) == 2.5
@test coeff((2.5x / 3x).impl.den) == 3
@test coeff((x / 3x).impl.den) == 3

@syms x y
@test issym((2x / 2y).impl.num)
@test (2x / 3y).impl.num.impl.coeff == 2
@test (2x / 3y).impl.den.impl.coeff == 3
@test coeff((2x / 3y).impl.num) == 2
@test coeff((2x / 3y).impl.den) == 3
@test (2x / -3x) == -2 // 3
@test (2.5x / 3x).impl.num == 2.5
@test (2.5x / 3x).impl.den == 3
Expand Down

0 comments on commit a1d27e1

Please sign in to comment.