How to bring your neural network to the web

Feb 10, 2020
Last update: Nov 19, 2020
~ 10 min

Artificial intelligence, neural networks, machine learning: I don’t know which of them is the bigger buzzword. If we look past the hype, though, there are some genuinely interesting use cases for machine learning in the browser.

For the lazy ones who just want to jump to the source code:
Here is the git repo for you 🙂
Or simply go to the finished website.

Today we will look at how to train a simple MNIST digit recogniser and then bring it into a website where we can see it in action. The article is therefore split into three parts:

  1. Training
  2. Export & import the pre-trained model into a website
  3. Build a simple website where we can use the model.

Also, I am not going to explain what machine learning is. There are plenty of guides, videos, podcasts, … that already do a much better job than I could, and it would be outside the scope of this article.

Photo by Natasha Connell on Unsplash

So the first thing to understand is that we will not train the model in the browser. That is a job for GPUs; the goal here is only to use a pre-trained model inside the browser. Training is a much more resource-intensive task than simply running the finished network on new input.

Training the model

So, the first step is to actually have a model. I will build it in TensorFlow 2.0 using the now-included Keras API. This means Python 🎉

The code below is basically an adapted version of the Keras hello world example.
If you want to run the code yourself (which you should!), simply head over to Google Colab, create a new notebook and paste the code. There you can run it for free on GPUs, which is pretty dope!

from tensorflow.keras.datasets import mnist
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Dropout, Flatten
from tensorflow.keras.layers import Conv2D, MaxPooling2D
from tensorflow.keras.utils import to_categorical

(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Reshaping for channels_last (tensorflow) with one channel
size = 28
print(x_train.shape, x_test.shape)
x_train = x_train.reshape(len(x_train), size, size, 1).astype('float32')
x_test = x_test.reshape(len(x_test), size, size, 1).astype('float32')
print(x_train.shape, x_test.shape)

# Normalize
upper = max(x_train.max(), x_test.max())
lower = min(x_train.min(), x_test.min())
print(f'Max: {upper} Min: {lower}')
x_train /= upper
x_test /= upper

total_classes = 10
y_train = to_categorical(y_train, total_classes)
y_test = to_categorical(y_test, total_classes)

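# Make the model
model = Sequential()
model.add(Conv2D(64, (3, 3), activation='relu', input_shape=(size, size, 1), data_format='channels_last'))
model.add(Conv2D(32, (3, 3), activation='relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
# Dropout rates as in the Keras MNIST example
model.add(Dropout(0.25))
# Flatten the 12x12x32 feature maps into one vector before the Dense layers
model.add(Flatten())
model.add(Dense(128, activation='relu'))
model.add(Dropout(0.5))
model.add(Dense(total_classes, activation='softmax'))

model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])

# Train (batch size and number of epochs are example values)
model.fit(x_train, y_train,
          batch_size=128,
          epochs=5,
          verbose=1,
          validation_data=(x_test, y_test))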

score = model.evaluate(x_test, y_test, verbose=0)
print('Test loss:', score[0])
print('Test accuracy:', score[1])

If we run this, we get pretty good accuracy. MNIST is not a very hard dataset to learn.
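To make the two preprocessing steps concrete, here is what they do to a single image and label, written out in plain JavaScript. This is only an illustration. `normalize` and `toCategorical` are hypothetical helper names, not Keras or TensorFlow.js functions. Keep the scaling in mind: the browser will have to feed the model pixel values in the same 0 to 1 range.

```javascript
// Scale raw grey values (0..255) into 0..1, like `x_train /= upper` does
// in the Python code (upper is 255 for MNIST).
function normalize(pixels, upper = 255) {
  return Float32Array.from(pixels, (v) => v / upper)
}

// One-hot encode a label the way to_categorical(y, 10) does:
// 3 -> [0, 0, 0, 1, 0, 0, 0, 0, 0, 0]
function toCategorical(label, totalClasses = 10) {
  const oneHot = new Array(totalClasses).fill(0)
  oneHot[label] = 1
  return oneHot
}

const pixels = normalize([0, 51, 255]) // values 0, ~0.2 and 1
const label = toCategorical(3)
console.log(pixels, label)
```

The one-hot encoding is also why the trained model answers with 10 numbers, one per digit.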

Export the model

Now, the conventional way to save a model is the model.save("model.h5") method provided by Keras, which uses the HDF5 (.h5) file format.
Unfortunately, TensorFlow.js cannot load this format directly in the browser, so we need another way.

There is a Python package called tensorflowjs (confusing, right? 😅) that provides exactly what we need. Install it with pip install tensorflowjs, then:

import tensorflowjs as tfjs

tfjs.converters.save_keras_model(model, './js')

This saves the model data inside the ./js folder, ready to be used.
Inside you will find a model.json, which describes the structure of the model and lists its weight files, and one or more files like group1-shard1of1.bin, which contain the trained weights.
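To get a feel for how model.json and the .bin file relate, here is a sketch that walks over a weights manifest and adds up the parameters. The manifest is hand-written for illustration. The layer names and shapes are my assumptions for the network above, with the 12x12x32 feature map flattened before the first Dense layer. The real file also contains more fields, such as the model topology. With float32 weights, every parameter takes 4 bytes in the .bin file.

```javascript
// Illustrative, hand-written excerpt of a model.json weightsManifest.
// Names and shapes are assumptions for the model above, not copied from a real file.
const weightsManifest = [
  {
    paths: ["group1-shard1of1.bin"],
    weights: [
      { name: "conv2d/kernel", shape: [3, 3, 1, 64], dtype: "float32" },
      { name: "conv2d/bias", shape: [64], dtype: "float32" },
      { name: "conv2d_1/kernel", shape: [3, 3, 64, 32], dtype: "float32" },
      { name: "conv2d_1/bias", shape: [32], dtype: "float32" },
      { name: "dense/kernel", shape: [4608, 128], dtype: "float32" },
      { name: "dense/bias", shape: [128], dtype: "float32" },
      { name: "dense_1/kernel", shape: [128, 10], dtype: "float32" },
      { name: "dense_1/bias", shape: [10], dtype: "float32" },
    ],
  },
]

// Number of values in a tensor of the given shape.
function numElements(shape) {
  return shape.reduce((acc, dim) => acc * dim, 1)
}

// Collect the shard files and count parameters (float32 = 4 bytes each).
function summarizeManifest(manifest) {
  let params = 0
  const files = []
  for (const group of manifest) {
    files.push(...group.paths)
    for (const w of group.weights) {
      if (w.dtype !== "float32") throw new Error(`unexpected dtype ${w.dtype}`)
      params += numElements(w.shape)
    }
  }
  return { files, params, bytes: params * 4 }
}

const summary = summarizeManifest(weightsManifest)
console.log(summary) // params: 610346, bytes: 2441384 (about 2.4 MB)
```

Under these assumptions, roughly 2.4 MB of weights must be downloaded before the page can predict anything. Almost all of it comes from the first Dense layer.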

Import the model

Now we are ready to import it. First we need to install the @tensorflow/tfjs package (for example with npm install @tensorflow/tfjs).

import * as tf from '@tensorflow/tfjs';

let model

// '/model.json' assumes model.json and its .bin file are served from the site root
tf.loadLayersModel('/model.json').then(m => {
    model = m
})
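`tf.loadLayersModel` returns a promise. That means `model` stays undefined until the download has finished, and clicking “test” too early would fail at `model.predict`. One way around this is to keep the promise instead of the resolved model. Here is a minimal sketch that assumes a small wrapper of my own. `createLazyModel` is a hypothetical helper, not a TensorFlow.js API.

```javascript
// Hypothetical helper: call loadFn once, on first use, and hand every caller
// the same promise. If loading fails, the next call tries again.
function createLazyModel(loadFn) {
  let pending = null
  return function getModel() {
    if (pending === null) {
      // The executor runs synchronously, so loadFn starts right away;
      // a synchronous throw inside loadFn becomes a rejection.
      pending = new Promise((resolve) => resolve(loadFn())).catch((err) => {
        pending = null // allow a retry on the next call
        throw err
      })
    }
    return pending
  }
}

// In the page (not run here):
// const getModel = createLazyModel(() => tf.loadLayersModel('/model.json'))
// const model = await getModel() // inside the click handler
```

The first click simply waits for the download instead of failing, and later clicks get the already-loaded model immediately.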

OK, how do we use it now?

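// ourData: the 784 (28x28) pixel values of our drawing, from 0 to 255
const input = Float32Array.from(ourData, v => v / 255) // same 0..1 range as in training
const tensor = tf.tensor(input, [1, 28, 28, 1])
const prediction = model.predict(tensor)

What is happening here?
In order to predict a value, we first need a tensor (a multi-dimensional array) with the same shape as the input we trained the model with. In our case that is 1x28x28x1: a batch of one image, 28x28 pixels, one channel.
The pixel values also have to be in the same range as during training, so we divide them by 255 to get floats between 0 and 1.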
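`tf.tensor(values, shape)` reads a flat array in row-major order, so the last dimension changes fastest. For the shape [1, 28, 28, 1], the 784 values must therefore be the image row by row, from the top-left to the bottom-right pixel. That is also the order in which the canvas hands out its pixels. Here is a small sketch of the index arithmetic. `flatIndex` is a hypothetical helper for illustration.

```javascript
// Row-major offset of element [b, row, col, ch] in a tensor of shape
// [batch, height, width, channels] (channels_last, as used in training).
function flatIndex([b, row, col, ch], [, height, width, channels]) {
  return ((b * height + row) * width + col) * channels + ch
}

const shape = [1, 28, 28, 1]
const size = shape.reduce((a, d) => a * d, 1)
console.log(size)                             // 784 values for one image
console.log(flatIndex([0, 1, 0, 0], shape))   // 28: first pixel of the second row
console.log(flatIndex([0, 27, 27, 0], shape)) // 783: the last pixel
```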

Using the canvas element to draw and predict numbers

I’m not going to talk about which bundler etc. I’m using. If you are interested, simply have a look at the git repo.

First, let’s write some basic HTML for the skeleton of our page.


<!DOCTYPE html>
<html>

<head>
    <meta charset="utf-8">
    <style>
        * {
            box-sizing: border-box;
            font-family: monospace;
        }

        body {
            padding: 0;
            margin: 0;
            height: 100vh;
            width: 100vw;
            display: flex;
            justify-content: center;
            align-items: center;
        }

        body>div {
            text-align: center;
        }

        div canvas {
            display: inline-block;
            border: 1px solid;
        }

        div input {
            display: inline-block;
            margin-top: .5em;
            padding: .5em 2em;
            background: white;
            outline: none;
            border: 1px solid;
            font-weight: bold;
        }
    </style>
</head>

<!-- init() from canvas.js sets up the drawing handlers -->
<body onload="init()">
    <div>
        <h1>MNIST (Pretrained)</h1>
        <canvas id="can" width="28" height="28"></canvas>
        <br />
        <input id="clear" type="button" value="clear">
        <br />
        <input id="test" type="button" value="test">
        <br />
        <h2 id="result"></h2>
        <a href="">
            <h3>source code</h3>
        </a>
    </div>
    <script src="./tf.js"></script>
    <script src="./canvas.js"></script>
</body>

</html>


Next we need some short code for drawing on a canvas.
It is adapted from this Stack Overflow answer and reduced to only the basics we need.

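In essence, it’s a canvas that listens to our mouse events and paints the pixels under the cursor black. Nothing more. The init() function wires everything up and has to be called once the page has loaded, for example from the body’s onload attribute.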

/* jslint esversion: 6, asi: true */

var canvas, ctx, w, h, flag = false,
    prevX = 0,
    currX = 0,
    prevY = 0,
    currY = 0,
    dot_flag = false;

var x = "black", // stroke colour
    y = 2;       // line width

function init() {
    canvas = document.getElementById('can');
    ctx = canvas.getContext("2d");
    w = canvas.width;
    h = canvas.height;

    canvas.addEventListener("mousemove", function (e) {
        findxy('move', e)
    }, false);
    canvas.addEventListener("mousedown", function (e) {
        findxy('down', e)
    }, false);
    canvas.addEventListener("mouseup", function (e) {
        findxy('up', e)
    }, false);
    canvas.addEventListener("mouseout", function (e) {
        findxy('out', e)
    }, false);

    window.document.getElementById('clear').addEventListener('click', erase)
}

// Draw a line segment from the previous to the current mouse position
function draw() {
    ctx.beginPath();
    ctx.moveTo(prevX, prevY);
    ctx.lineTo(currX, currY);
    ctx.strokeStyle = x;
    ctx.lineWidth = y;
    ctx.stroke();
    ctx.closePath();
}

// Make the whole canvas transparent again
function erase() {
    ctx.clearRect(0, 0, w, h);
}

function findxy(res, e) {
    if (res == 'down') {
        prevX = currX;
        prevY = currY;
        currX = e.clientX - canvas.offsetLeft;
        currY = e.clientY - canvas.offsetTop;

        flag = true;
        dot_flag = true;
        if (dot_flag) {
            // Paint a single dot where the mouse was pressed
            ctx.fillStyle = x;
            ctx.fillRect(currX, currY, 2, 2);
            dot_flag = false;
        }
    }
    if (res == 'up' || res == "out") {
        flag = false;
    }
    if (res == 'move') {
        if (flag) {
            prevX = currX;
            prevY = currY;
            currX = e.clientX - canvas.offsetLeft;
            currY = e.clientY - canvas.offsetTop;
            draw();
        }
    }
}
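The canvas code computes positions as `e.clientX - canvas.offsetLeft`. That only matches canvas pixels while the canvas is shown at its native 28x28 size and the page is not scrolled. A 28x28 drawing area is tiny, so you may want to enlarge it with CSS. Here is a minimal sketch of a coordinate mapping that would cope with that, assuming you pass in the result of `canvas.getBoundingClientRect()`. `toCanvasCoords` is a hypothetical helper, and it ignores the 1px border, which shifts the result by only a fraction of a canvas pixel.

```javascript
// Hypothetical helper: map mouse coordinates (clientX/clientY) to canvas pixel
// coordinates, even when CSS displays the canvas larger than its width/height attributes.
// rect: { left, top, width, height } as returned by getBoundingClientRect().
function toCanvasCoords(clientX, clientY, rect, canvasWidth, canvasHeight) {
  return {
    x: ((clientX - rect.left) * canvasWidth) / rect.width,
    y: ((clientY - rect.top) * canvasHeight) / rect.height,
  }
}

// A 28x28 canvas shown at 280x280 CSS pixels, 10px from the left and 20px from the top:
const rect = { left: 10, top: 20, width: 280, height: 280 }
console.log(toCanvasCoords(150, 160, rect, 28, 28)) // { x: 14, y: 14 }

// In findxy this would replace the offsetLeft/offsetTop arithmetic (not run here):
// const { x, y } = toCanvasCoords(e.clientX, e.clientY, canvas.getBoundingClientRect(), canvas.width, canvas.height)
```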


And now the glue that puts this together: the code that listens to the “test” button.

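import * as tf from '@tensorflow/tfjs';

let model

// Load the converted model (model.json + weight shards) once
tf.loadLayersModel('/model.json').then(m => {
    model = m
})

window.document.getElementById('test').addEventListener('click', async () => {
    const canvas = window.document.querySelector('canvas')

    // RGBA values of all 28x28 pixels, 4 bytes per pixel
    const { data } = canvas.getContext('2d').getImageData(0, 0, 28, 28)

    // Keep only the alpha channel and scale it to 0..1, like the training data
    const pixels = Float32Array.from(data.filter((_, i) => i % 4 === 3), v => v / 255)
    const tensor = tf.tensor(pixels, [1, 28, 28, 1])
    const prediction = model.predict(tensor)

    // Ten probabilities, one per digit; the index of the largest one is the guess
    const result = await prediction.data()
    const guessed = result.indexOf(Math.max(...result))
    window.document.querySelector('#result').innerText = guessed
})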

Here we need to explain a few things.
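canvas.getContext('2d').getImageData(0, 0, 28, 28) returns an ImageData object. Its data property is a flat Uint8ClampedArray covering the 28x28 rectangle that starts at (0,0), row by row, with four values (red, green, blue, alpha) per pixel.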

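Then, instead of simply passing the data to the tensor, we use data.filter to keep only every 4th value. Each pixel has 3 colour channels plus 1 alpha channel, but we only need to know whether a pixel has been painted. The canvas starts out fully transparent and we paint in opaque black, so the alpha value is 0 for untouched pixels and 255 for painted ones, with values in between along anti-aliased edges. That matches MNIST, where the digit is bright on a dark background. Alpha is the 4th value of each pixel, so we keep the indices where i mod 4 equals 3: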

data.filter((_, i) => i % 4 === 3)
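The same selection can be written as a plain loop that also does the 0 to 1 scaling the model was trained with, without building an intermediate array. Here is a minimal sketch. `alphaToInput` is a hypothetical helper, and `rgba` is assumed to be `ImageData.data`. In the click handler you would then call `tf.tensor(alphaToInput(data), [1, 28, 28, 1])`.

```javascript
// Keep only the alpha value (every 4th byte) of each RGBA pixel and scale it to 0..1.
// rgba: ImageData.data, a Uint8ClampedArray with 4 values per pixel.
function alphaToInput(rgba) {
  const out = new Float32Array(rgba.length / 4)
  for (let i = 0; i < out.length; i++) {
    out[i] = rgba[i * 4 + 3] / 255 // index 4*i + 3 is the alpha of pixel i
  }
  return out
}

// Three pixels: opaque black, untouched (transparent), half-transparent black.
const sample = Uint8ClampedArray.from([0, 0, 0, 255, 0, 0, 0, 0, 0, 0, 0, 128])
const input = alphaToInput(sample)
console.log(input.length) // 3
```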

Lastly, we need to interpret the result. await prediction.data() returns an array with 10 items, because we trained the model for 10 possible outcomes: 10 digits, right?
Each item is the probability the network assigns to one digit, so all values lie between 0 and 1 and add up to 1. The index of the highest value is our solution.
Note that this maximum is usually a bit below 1, so searching for an exact 1 only works when the network happens to be completely certain.
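To see why the maximum matters, here is a small sketch of what the final softmax layer produces and how to pick the digit from it. `softmax` and `argmax` are written out by hand for illustration, and the scores fed in are made up.

```javascript
// softmax turns raw scores (logits) into probabilities that sum to 1;
// the final Dense layer with activation='softmax' does this inside the model.
function softmax(logits) {
  const max = Math.max(...logits) // subtract the max for numerical stability
  const exps = logits.map((v) => Math.exp(v - max))
  const sum = exps.reduce((a, b) => a + b, 0)
  return exps.map((v) => v / sum)
}

// Index of the largest value, i.e. the predicted digit.
function argmax(values) {
  let best = 0
  for (let i = 1; i < values.length; i++) {
    if (values[i] > values[best]) best = i
  }
  return best
}

const probs = softmax([1, 2, 6, 1, 0, 0, 1, 2, 3, 0]) // made-up scores
console.log(argmax(probs))    // 2
console.log(probs.indexOf(1)) // -1: no entry is exactly 1
```

TensorFlow.js can also do this step on the tensor itself with `prediction.argMax(-1)`. That returns another tensor, which you would read with `.data()` again.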

I hope this helped you understand the process better. It was pretty confusing at first for me too 😬