Model API #18
#include <bnn/nn/models.hpp>
#include <bnn/nn/layers.hpp>
#include <bnn/metrics/losses.hpp>
#include <bnn/nn/optimisers.hpp>
#include <bnn/nn/activations.hpp>
#include <bnn/data/read.hpp>
#include <bnn/core/tensor.hpp>
#include <vector>
#include <iostream>

using namespace std;
using namespace bnn;

// Assumption: the sample never defines data_type; float matches TensorCPU<float>.
using data_type = float;

int main()
{
    // Load the training images and labels.
    TensorCPU<float>* mnist = load_data("mnist_train_images");
    TensorCPU<float>* labels = load_data("mnist_train_labels");

    // Layer sizes for the 784-200-200-10 network.
    vector<unsigned> shape_1 = {784};
    vector<unsigned> shape_2 = {200};
    vector<unsigned> shape_3 = {10};

    // C++ has no keyword arguments, so the intended parameter names are given as comments.
    Dense<data_type>* Dense_1 = new Dense<data_type>(/*input_shape=*/shape_1, /*output_shape=*/shape_2, /*activation=*/relu);
    Dense<data_type>* Dense_2 = new Dense<data_type>(shape_2, shape_2, /*activation=*/relu);
    Dense<data_type>* Output = new Dense<data_type>(shape_2, shape_3, /*activation=*/softmax);

    // Build the model from its layers and train it.
    Model* mnist_nn = new Model(Dense_1, Dense_2, Output);
    mnist_nn->train(/*input_data=*/mnist, /*targets=*/labels, /*optimiser=*/adam, /*loss=*/cross_entropy, /*batch_size=*/128, /*epochs=*/600);

    // Load the test set and report the error rate.
    TensorCPU<float>* mnist_test = load_data("mnist_test_images", /*flatten=*/true, /*normalise=*/true);
    TensorCPU<float>* labels_test = load_data("mnist_test_labels", /*one_hot=*/true);
    float test_error = mnist_nn->test(/*test_input=*/mnist_test, /*targets=*/labels_test, /*eval_metric=*/error_rate);
    std::cout << test_error << std::endl;
}
The above is sample code for training and testing a classification model. It highlights several things that need to be implemented before work on the model APIs can start.
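One of the pieces the sample relies on is the `error_rate` evaluation metric. As a rough sketch of what it might compute, the function below takes per-sample class scores and one-hot targets and returns the fraction of misclassified samples. It uses `std::vector` rather than `TensorCPU` because the tensor interface is not shown here, and both the signature and the `argmax` helper are assumptions for illustration, not existing bnn API.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <iterator>
#include <vector>

// Hypothetical helper (not part of bnn): index of the largest entry in a row.
std::size_t argmax(const std::vector<float>& row)
{
    assert(!row.empty());
    return static_cast<std::size_t>(
        std::distance(row.begin(), std::max_element(row.begin(), row.end())));
}

// Hypothetical metric (not part of bnn): fraction of samples whose
// predicted class differs from the class marked in the one-hot target.
float error_rate(const std::vector<std::vector<float>>& predictions,
                 const std::vector<std::vector<float>>& targets)
{
    assert(!predictions.empty());
    assert(predictions.size() == targets.size());

    std::size_t wrong = 0;
    for (std::size_t i = 0; i < predictions.size(); ++i)
    {
        if (argmax(predictions[i]) != argmax(targets[i]))
            ++wrong;
    }
    return static_cast<float>(wrong) / static_cast<float>(predictions.size());
}
```

With four samples of which one is misclassified, this returns 0.25. A tensor-based version would need some way to iterate over rows of the output, which is one of the things the tensor API would have to provide.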
A few functions are to be added in
Description of the problem
APIs for models (train, test) need to be discussed in this issue. A simple use case to start with:
Image classification on the MNIST dataset with a 784-200-200-10 network.
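To give a sense of the size of this use case, the sketch below counts the trainable parameters of a stack of fully connected layers. It assumes each dense layer has one weight per input-output pair plus one bias per output unit. The function name `dense_parameter_count` is made up for illustration and is not part of bnn.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper (not part of bnn): counts weights plus biases for a
// stack of dense layers. sizes = {784, 200, 200, 10} describes three layers.
// Assumes every layer has a bias term for each output unit.
std::size_t dense_parameter_count(const std::vector<std::size_t>& sizes)
{
    assert(sizes.size() >= 2);

    std::size_t total = 0;
    for (std::size_t i = 0; i + 1 < sizes.size(); ++i)
    {
        const std::size_t inputs = sizes[i];
        const std::size_t outputs = sizes[i + 1];
        total += inputs * outputs + outputs;  // weight matrix + bias vector
    }
    return total;
}
```

Under that assumption, the 784-200-200-10 network has 157000 + 40200 + 2010 = 199210 parameters, which is small enough to train on CPU.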
Example of the problem
References/Other comments