Skip to content

Commit

Permalink
Merge pull request #24 from SciML/ChrisRackauckas-patch-1
Browse files Browse the repository at this point in the history
Missing name in NueralNetworkBlock
  • Loading branch information
ChrisRackauckas authored Apr 21, 2024
2 parents 2d6db67 + 7ad050c commit 4b37aba
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 6 deletions.
2 changes: 1 addition & 1 deletion docs/src/friction.md
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ chain = Lux.Chain(
Lux.Dense(10 => 10, Lux.mish, use_bias = false),
Lux.Dense(10 => 1, use_bias = false)
)
nn = NeuralNetworkBlock(1, 1; chain = chain, rng = StableRNG(1111))
@named nn = NeuralNetworkBlock(1, 1; chain = chain, rng = StableRNG(1111))
eqs = [connect(model.nn_in, nn.output)
connect(model.nn_out, nn.input)]
Expand Down
10 changes: 6 additions & 4 deletions src/ModelingToolkitNeuralNets.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@ include("utils.jl")
chain = multi_layer_feed_forward(n_input, n_output),
rng = Xoshiro(0),
init_params = Lux.initialparameters(rng, chain),
eltype = Float64)
eltype = Float64,
name)
Create an `ODESystem` with a neural network inside.
"""
Expand All @@ -26,7 +27,8 @@ function NeuralNetworkBlock(n_input = 1,
chain = multi_layer_feed_forward(n_input, n_output),
rng = Xoshiro(0),
init_params = Lux.initialparameters(rng, chain),
eltype = Float64)
eltype = Float64,
name)
ca = ComponentArray{eltype}(init_params)

@parameters p[1:length(ca)] = Vector(ca)
Expand All @@ -39,8 +41,8 @@ function NeuralNetworkBlock(n_input = 1,

eqs = [output.u ~ out]

@named ude_comp = ODESystem(
eqs, t_nounits, [], [p, T], systems = [input, output])
ude_comp = ODESystem(
eqs, t_nounits, [], [p, T]; systems = [input, output], name)
return ude_comp
end

Expand Down
2 changes: 1 addition & 1 deletion test/lotka_volterra.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ end
model = lotka_ude()

chain = multi_layer_feed_forward(2, 2)
nn = NeuralNetworkBlock(2, 2; chain, rng = StableRNG(42))
@named nn = NeuralNetworkBlock(2, 2; chain, rng = StableRNG(42))

eqs = [connect(model.nn_in, nn.output)
connect(model.nn_out, nn.input)]
Expand Down

0 comments on commit 4b37aba

Please sign in to comment.