Can't load the LSTM model pre-trained in Python to R #1439

Closed
gjfreitas opened this issue May 3, 2024 · 1 comment
gjfreitas commented May 3, 2024

I have the following model architecture in Python:

# Define the LSTM model
model = Sequential()
model.add(LSTM(256, input_shape=(X_train.shape[1], X_train.shape[2])))
model.add(Dropout(0.2))  # Add dropout to prevent overfitting
model.add(Dense(128, activation='relu'))
model.add(Dropout(0.2))  # Add dropout to prevent overfitting
model.add(Dense(64, activation='relu'))
model.add(Dropout(0.2))  # Add dropout to prevent overfitting
model.add(Dense(32, activation='relu'))
model.add(Dropout(0.2))  # Add dropout to prevent overfitting
model.add(Dense(2))  # Output layer with 2 neurons for p and q
#model.compile(optimizer=RMSprop(), loss='mae')
model.compile(optimizer='adam', loss='mse')

# Define early stopping to prevent overfitting
early_stopping = EarlyStopping(monitor='val_loss', patience=10)

# Train the model
history = model.fit(X_train, y_train, epochs=100, batch_size=16, validation_data=(X_test, y_test), verbose=1, callbacks=[early_stopping])

I have tried saving the trained model in the .keras, .h5, and .json formats and then importing it into R to test it there.

None of these approaches has worked.

Does anyone have an idea how to do this?

This is one of the errors I'm getting:

Error in py_call_impl(callable, call_args$unnamed, call_args$named) : 
  TypeError: Could not locate function 'mae'. Make sure custom classes are decorated with `@keras.saving.register_keras_serializable()`. Full object config: {'module': 'keras.metrics', 'class_name': 'function', 'config': 'mae', 'registered_name': 'mae'}
Run `reticulate::py_last_error()` for details.

t-kalinowski (Member) commented

Hi, can you please post a reproducible example, something self-contained that I can run locally? Please add the necessary import statements to the Python script, define X_train = np.ones((3, 4, 5)), and call model.save(). Similarly, please show how you are loading the model in R when you see the error.
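
For illustration, a self-contained script of the kind requested above might look roughly like the sketch below. It assumes Keras is installed and reuses the architecture from the original post with the dummy X_train suggested here. The y_train array, the file name lstm_repro.keras, and the helper names make_dummy_data and build_and_save are placeholders that do not come from the original post; early stopping and validation data are left out to keep the example small.

```python
import numpy as np


def make_dummy_data():
    """Return tiny placeholder arrays shaped like the data in the post."""
    # (samples, timesteps, features), as suggested in the comment above
    X_train = np.ones((3, 4, 5))
    # Assumption: two targets (p and q) per sample, matching Dense(2)
    y_train = np.ones((3, 2))
    return X_train, y_train


def build_and_save(path="lstm_repro.keras"):
    """Build the model from the post, fit one epoch, and save it to `path`.

    `path` is a hypothetical file name. Keras is imported inside the
    function so that the data helper can be used without it.
    """
    from keras.models import Sequential
    from keras.layers import LSTM, Dropout, Dense

    X_train, y_train = make_dummy_data()

    # Same layer stack as in the original post
    model = Sequential()
    model.add(LSTM(256, input_shape=(X_train.shape[1], X_train.shape[2])))
    model.add(Dropout(0.2))
    model.add(Dense(128, activation='relu'))
    model.add(Dropout(0.2))
    model.add(Dense(64, activation='relu'))
    model.add(Dropout(0.2))
    model.add(Dense(32, activation='relu'))
    model.add(Dropout(0.2))
    model.add(Dense(2))  # Output layer with 2 neurons for p and q
    model.compile(optimizer='adam', loss='mse')

    # One short epoch is enough to produce a trained, saveable model
    model.fit(X_train, y_train, epochs=1, batch_size=16, verbose=0)
    model.save(path)
    return path
```

Calling build_and_save() writes the model file. The R half of the report would then be the exact call used to load that file, together with the full output of reticulate::py_last_error().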
