Software 2.0 - deeplearning4j SparkJava integration

Today I'd like to present a simple deeplearning4j integration example: a REST API built with the SparkJava framework. SparkJava is a lightweight framework for building REST APIs, and it can also be deployed to AWS Lambda if you want to use that service provider.

tl;dr: the source code is available on GitHub.

First, let me present the scenario behind this example:

  1. The user opens an HTML page in a browser that allows drawing.
  2. The user draws a digit on a canvas.
  3. The digit (a PNG image) is sent to the server in an HTTP POST request.
  4. The REST API endpoint receives this request and stores the file in temporary storage.
  5. The file is converted into a format understood by the neural network (this involves both image operations and vectorization).
  6. The network component performs the analysis.
  7. The analysis result is returned to the browser.
  8. The browser draws the digit recognized on the server side alongside the original image.

This application will be built in the following steps:

  1. Train a neural network to recognize digits.
  2. Prepare a REST API for analysing images.
  3. Prepare an HTML canvas application for drawing and analysing results.
  4. Extend the backend to accept white-background images of different sizes.

Training the network (LeNet)

To train the neural network I'm going to use one of the dl4j examples, namely LeNet MNIST. It contains everything we need for this demo: both the network topology and the required training and validation data.

In the deeplearning4j examples you may find that there are 3 LeNet MNIST examples. I don't know why that happened, but I decided to use the one with the more recent API. So please check out the examples to your hard drive and start training! Everything should work out of the box, which is really great! The example doesn't save the trained model, though, so add the following line right after the line that logs "****************Example finished********************" (please make sure your models directory exists!):

```java
// needs: import org.deeplearning4j.util.ModelSerializer; and import java.io.File;
ModelSerializer.writeModel(model, new File("src/main/resources/models/" + "/"), true);
```

Below you can see a video of my training.

Prepare a REST API for analysing images

Having a trained model in the examples folder is the first step. Now I'll show you how to create a simple REST API using the SparkJava framework. Since this is a simple example, I took a shortcut: both the front end and the server side are served from the same module. Normally I store the front end in S3 and connect it to my backend, but the approach in this write-up is easier to follow. First we need to load our model and define an endpoint for the application. The endpoint will accept HTTP POST requests carrying the image uploaded by the user, store the file in a local directory and feed it to the neural network. As a result, the client will receive a small JSON response with the probability of the predicted number and a label describing it.

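As a first step, we define the required dependencies for both SparkJava and deeplearning4j:

```xml
<dependency>
    <groupId>org.nd4j</groupId>
    <artifactId>nd4j-native-platform</artifactId>
    <version>${nd4j.version}</version>
</dependency>
<dependency>
    <groupId>org.deeplearning4j</groupId>
    <artifactId>deeplearning4j-core</artifactId>
    <version>${dl4j.version}</version>
</dependency>
<dependency>
    <groupId>com.sparkjava</groupId>
    <artifactId>spark-core</artifactId>
    <version>2.8.0</version>
</dependency>
```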

This might seem like overkill, but this is where the Java ecosystem, with its Maven repositories, really shines. Each dependency declares its own dependencies (so-called transitive dependencies) with clearly defined versions, so whenever you want to run this application you'll be able to do so. In the repository you'll find a few more dependencies: I prefer log4j2, and the rest are adapters for other logging frameworks, along with the specific Java and dl4j versions. This setup guarantees that all logging is controlled through log4j2.xml.

Our development environment is now ready, so we can start writing the code for our REST endpoint.

First we need to read our model into memory so it can be used for recognizing images. I did this by reading the file from the src/main/resources/models directory:

```java
import java.io.File;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;

public static void main(String[] args) throws Exception {
    // false = don't load the updater state; we only need the network for inference
    MultiLayerNetwork net = MultiLayerNetwork.load(new File("src/main/resources/models/"), false);
}
```

NOTE: As you can see, the model location is exactly the same as in the dl4j examples, just for simplicity; since the model is just a zip file, it can be stored anywhere.

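Next we need to make sure we have temporary storage for the uploaded files:

```java
File uploadDir = new File("upload");
uploadDir.mkdir();                      // creates the directory if it doesn't exist yet
staticFiles.externalLocation("upload"); // Spark.staticFiles, statically imported
```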

And finally we can add the endpoint definition:

```java
post("/mnist", (req, res) -> {
    res.type("application/json");
    return "{}"; // placeholder: a Spark Route must return a value
});
```

So far our endpoint is pretty simple: it only sets the response type to application/json. Since we're going to accept files, we need to make sure each one is fully uploaded and stored in a temporary location for further processing. This is done by the receiveUploadedFile method:

```java
Path tempFile = receiveUploadedFile(uploadDir, req);
```
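The post doesn't show receiveUploadedFile itself. A minimal JDK-only sketch of its storage half could look like the following; the class name `UploadStorage`, the method `store`, the `digit-` prefix and `.png` suffix, and the idea of receiving an already-opened InputStream are all hypothetical. In SparkJava the stream would typically come from the multipart request (the raw servlet request's part, once multipart support is configured), which is not shown here.

```java
import java.io.IOException;
import java.io.InputStream;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.StandardCopyOption;

// Hypothetical helper: stores an uploaded stream under a unique name in uploadDir.
final class UploadStorage {

    private UploadStorage() {
    }

    static Path store(Path uploadDir, InputStream upload) throws IOException {
        // Make sure the target directory exists (no-op if it already does).
        Files.createDirectories(uploadDir);
        // Unique file name, so concurrent uploads don't overwrite each other.
        Path tempFile = Files.createTempFile(uploadDir, "digit-", ".png");
        // createTempFile already created an empty file, so replacing must be allowed.
        Files.copy(upload, tempFile, StandardCopyOption.REPLACE_EXISTING);
        return tempFile;
    }
}
```

Using a generated temp file name rather than the client-supplied file name avoids collisions and keeps user input out of the server's file paths.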

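Once we have the file, we can start processing it. I used the ImageIO static call to read the file into memory; the image is then converted into an INDArray, which is fed to the network:

```java
BufferedImage inputImage = ImageIO.read(tempFile.toFile());
// 'gray' is the grayscale image derived from inputImage (not shown in this snippet)
INDArray digit = toINDArray(gray);
INDArray output = net.output(digit);
```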
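The snippet above uses a `gray` image whose creation isn't shown. One way to produce it with plain AWT is to draw the decoded image onto a `TYPE_BYTE_GRAY` BufferedImage. This is a sketch: the class and method names are hypothetical, and painting the background white first (so that transparent PNG pixels come out white instead of black) is my assumption about what canvas uploads may contain.

```java
import java.awt.Color;
import java.awt.Graphics2D;
import java.awt.image.BufferedImage;

// Hypothetical helper: converts any BufferedImage into an 8-bit grayscale copy.
final class GrayConverter {

    private GrayConverter() {
    }

    static BufferedImage toGray(BufferedImage input) {
        BufferedImage gray = new BufferedImage(
                input.getWidth(), input.getHeight(), BufferedImage.TYPE_BYTE_GRAY);
        Graphics2D g = gray.createGraphics();
        try {
            // White background first, so fully transparent pixels end up white.
            g.setColor(Color.WHITE);
            g.fillRect(0, 0, input.getWidth(), input.getHeight());
            // Drawing onto a TYPE_BYTE_GRAY image performs the color-to-gray conversion.
            g.drawImage(input, 0, 0, null);
        } finally {
            g.dispose();
        }
        return gray;
    }
}
```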

Believe it or not, it's that simple: you only need to convert the file into a format understood by the neural network and boom, you're all set!

NOTE: Remember that this is just the initial version of this endpoint, and currently it handles only images from the MNIST dataset itself.

Converting a BufferedImage to an INDArray is very simple, but I want to share it because there are at least 2 things worth mentioning:

```java
import java.awt.Color;
import java.awt.image.BufferedImage;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

private static INDArray toINDArray(BufferedImage gray) {
    INDArray digit = Nd4j.create(28, 28);
    for (int i = 0; i < 28; i++) {          // i walks the image width (x)
        for (int j = 0; j < 28; j++) {      // j walks the image height (y)
            Color c = new Color(gray.getRGB(i, j));
            // INDArray is indexed [row, column] = [y, x], hence (j, i)
            digit.putScalar(j, i, (c.getGreen() & 0xFF));
        }
    }
    // Add batch and channel dimensions, then scale 0..255 down to 0.0..1.0
    return digit.reshape(1, 1, 28, 28).divi(0xff);
}
```

The first thing is the processing order: we're used to reading images by width, then by height, while an INDArray is indexed by height, then width (row, column). That's why the order of indices in the putScalar call is inverted (j, i). Second, the divi operation at the end normalizes the input into the range [0.0, 1.0] (both ends inclusive).
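To make the index order concrete without pulling in ND4J, here is a JDK-only sketch of the same conversion that fills a flat float array in row-major order (row = y, column = x), which matches ND4J's default 'c' ordering. The class and method names are hypothetical; like toINDArray, it assumes a 28x28 grayscale image and reads the green channel.

```java
import java.awt.image.BufferedImage;

// Hypothetical JDK-only counterpart of toINDArray: row-major, normalized to [0, 1].
final class PixelArrays {

    static final int SIZE = 28;

    private PixelArrays() {
    }

    static float[] toNormalizedPixels(BufferedImage gray) {
        float[] pixels = new float[SIZE * SIZE];
        for (int x = 0; x < SIZE; x++) {          // x walks the width (i in toINDArray)
            for (int y = 0; y < SIZE; y++) {      // y walks the height (j in toINDArray)
                int green = (gray.getRGB(x, y) >> 8) & 0xFF;
                // Row index comes from y, column from x: same as putScalar(j, i, ...).
                pixels[y * SIZE + x] = green / 255f;
            }
        }
        return pixels;
    }
}
```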

There's still one piece of the puzzle missing: how do we interpret the output of the NN so we can respond with proper JSON? Our NN output is an INDArray vector with 10 entries, one prediction per digit. From there it's really simple: using the INDArray interface, I first get the maximum prediction value (using the maxNumber method), then find the corresponding index, which is the recognized number, and send both back as the response.
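That part of the code isn't included in the post, so here is a minimal JDK-only sketch of the idea: find the largest prediction, use its index as the digit, and format a small JSON string. It assumes the ten predictions have already been copied into a `double[]` (with ND4J, INDArray's toDoubleVector() should do that for a 1x10 output); the class name and the JSON field names `label` and `probability` are hypothetical.

```java
import java.util.Locale;

// Hypothetical helper that turns the 10 network outputs into a JSON response.
final class PredictionFormatter {

    private PredictionFormatter() {
    }

    // Index of the largest value; for MNIST the index is the digit itself.
    static int argMax(double[] predictions) {
        if (predictions.length == 0) {
            throw new IllegalArgumentException("no predictions");
        }
        int best = 0;
        for (int i = 1; i < predictions.length; i++) {
            if (predictions[i] > predictions[best]) {
                best = i;
            }
        }
        return best;
    }

    static String toJson(double[] predictions) {
        int digit = argMax(predictions);
        // Locale.ROOT keeps '.' as the decimal separator regardless of server locale.
        return String.format(Locale.ROOT,
                "{\"label\":\"%d\",\"probability\":%.4f}", digit, predictions[digit]);
    }
}
```

Building the string by hand keeps the demo dependency-free; a JSON library would be the more robust choice once the response grows.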

NOTE: Yes, I know this isn't how it should be handled, but I want to keep this demo as short as possible.

And there it is: the initial version of our REST endpoint.

Prepare an HTML canvas application for drawing and analysing results

Next we need a client-side application that feeds our REST endpoint with images. The easiest way is to use an HTML canvas, which allows drawing in the browser. This makes it very easy for users to get started: just open a URL, draw your digit and receive the result.

So let's start with the HTML canvas. It's an element introduced in HTML5 that probably everyone has heard of, but for actual drawing we need the 2D drawing API, and building a drawing tool on top of it from scratch can be a little tricky, so I decided to use Sketchpad. It looked pretty lightweight and had all the functionality I needed. In the constructor you pass the width and height of your canvas and its query selector:

```javascript
const sketchpad = new Sketchpad({
    element: '#sketchpad',
    width: 200,
    height: 200
});
sketchpad.penSize = 12;
```

After initializing the Sketchpad object, I decided to rely on the mouseup event, so the image is sent to the server automatically once the user stops drawing for at least 700 ms (achieved with a debounced function copied from a great article by Hajime Yamasaki Vukelic explaining the difference between debounce and throttle; I really recommend reading it if you're not sure what they are). Unfortunately, I had to use some boilerplate code to convert the canvas content into a Blob that can easily be passed to the fetch method for posting.

But it all worked great, and you can try it yourself. There's just one small issue: my digits are written in black on a white background, which is the opposite of the MNIST dataset images. We deal with that in the next section.

Extend the backend to accept white-background images of different sizes

With all those pieces in place, we can already connect our client to our backend, but there's still one thing left: the MNIST dataset contains white digits on a black background, while the canvas demo produces black digits on a white background. So we need to normalize the data so our network understands it. Getting the negative of an image is very simple: just subtract the current pixel value from 255 and there you go :)
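The backend code for this step isn't reproduced here, so below is a minimal AWT sketch of both parts: scaling an arbitrary-sized grayscale image to the 28x28 MNIST size, then taking the negative with 255 - value. The class name, the bilinear interpolation hint and the decision to stretch rather than preserve the aspect ratio are my assumptions, not necessarily what the repository does.

```java
import java.awt.Graphics2D;
import java.awt.RenderingHints;
import java.awt.image.BufferedImage;
import java.awt.image.WritableRaster;

// Hypothetical helper: resizes to 28x28 and inverts a grayscale image.
final class MnistPreprocessor {

    static final int SIZE = 28;

    private MnistPreprocessor() {
    }

    // Scales the image to SIZE x SIZE (aspect ratio is not preserved).
    static BufferedImage resize(BufferedImage gray) {
        BufferedImage scaled = new BufferedImage(SIZE, SIZE, BufferedImage.TYPE_BYTE_GRAY);
        Graphics2D g = scaled.createGraphics();
        try {
            g.setRenderingHint(RenderingHints.KEY_INTERPOLATION,
                    RenderingHints.VALUE_INTERPOLATION_BILINEAR);
            g.drawImage(gray, 0, 0, SIZE, SIZE, null);
        } finally {
            g.dispose();
        }
        return scaled;
    }

    // Negative: black-on-white becomes white-on-black, like MNIST.
    static BufferedImage invert(BufferedImage gray) {
        if (gray.getType() != BufferedImage.TYPE_BYTE_GRAY) {
            throw new IllegalArgumentException("expected a TYPE_BYTE_GRAY image");
        }
        BufferedImage negative = new BufferedImage(
                gray.getWidth(), gray.getHeight(), BufferedImage.TYPE_BYTE_GRAY);
        WritableRaster in = gray.getRaster();
        WritableRaster out = negative.getRaster();
        for (int y = 0; y < gray.getHeight(); y++) {
            for (int x = 0; x < gray.getWidth(); x++) {
                out.setSample(x, y, 0, 255 - in.getSample(x, y, 0));
            }
        }
        return negative;
    }
}
```

In the endpoint this would run on the grayscale image before the INDArray conversion, e.g. `toINDArray(MnistPreprocessor.invert(MnistPreprocessor.resize(gray)))`.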


You may be wondering why I used AWT and BufferedImage rather than a more sophisticated mechanism. My idea is to keep as much processing as possible on the Java side, because maintaining different technologies in the same project is usually tough. Yes, I know the image operations could be delegated to libraries like JavaCV, but learning a new API only for the sake of simple image manipulation seems like overkill to me.


That's it! As you can see, the real benefit of a Java-based DL library is how easy it is to integrate with whatever Java library or framework you like, leaving us in our safe, statically typed world...
