Skip to content

Commit

Permalink
Avoid JIT compilation errors
Browse files Browse the repository at this point in the history
  • Loading branch information
drasmuss committed Jul 10, 2024
1 parent ceb2d34 commit bb7ebd7
Showing 1 changed file with 8 additions and 1 deletion.
9 changes: 8 additions & 1 deletion keras_lmu/tests/test_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -534,6 +534,8 @@ def test_fit(feedforward, discretizer, trainable_theta):
loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
optimizer="adam",
metrics=["accuracy"],
# can't JIT compile the `tf.linalg.expm` operation used in this particular case
jit_compile=not (trainable_theta and discretizer == "zoh"),
)

model.fit(x_train, y_train, epochs=10, validation_split=0.2)
Expand Down Expand Up @@ -633,7 +635,12 @@ def test_theta_update(discretizer, trainable_theta, tmp_path):
lmu = keras.layers.RNN(lmu_cell)(inputs)
model = keras.Model(inputs=inputs, outputs=lmu)

model.compile(loss=keras.losses.MeanSquaredError(), optimizer="adam")
model.compile(
loss=keras.losses.MeanSquaredError(),
optimizer="adam",
# can't JIT compile the `tf.linalg.expm` operation used in this particular case
jit_compile=not (trainable_theta and discretizer == "zoh"),
)

# make sure theta_inv is set correctly to initial value
assert np.allclose(lmu_cell.theta_inv.numpy(), 1 / theta)
Expand Down

0 comments on commit bb7ebd7

Please sign in to comment.