Today I'd like to present a simple deeplearning4j integration example and a REST API built with the SparkJava framework. SparkJava is a very simple framework for building REST APIs and is supported by AWS Lambda if deployment to that provider is needed.
tl;dr: Source code is available on GitHub.
In order to train the neural network, I'm going to use one of the dl4j examples, namely LeNet MNIST. It contains everything we need for this demo: both the network topology and the required training and validation data.
After going through the deeplearning4j examples, you may find that there are three LeNet MNIST examples. I don't know why that is, but I decided to use the one with the more recent API: https://github.com/deeplearning4j/dl4j-examples/blob/master/dl4j-examples/src/main/java/org/deeplearning4j/examples/convolution/LenetMnistExample.java. So please check out the examples to your hard drive and start training! Everything works out of the box, which is really great! However, the example doesn't save the trained model, so adding the following line fixes that:
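A minimal sketch of that line, assuming `model` is the `MultiLayerNetwork` the example trains; the target file name is my choice, not part of the original example:

```java
import java.io.File;
import org.deeplearning4j.util.ModelSerializer;

// Save the trained network; `true` also stores the updater state,
// so training could be resumed later if needed
ModelSerializer.writeModel(model, new File("trained_mnist_model.zip"), true);
```
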
Below you can see a video from my training.
Having a trained model in the examples folder is the first step. Now I'll show you how to create a simple REST API using the SparkJava framework. Since this is a simple example, I'm taking a small shortcut: serving both the front end and the back end from the same module. Normally I store the front end in S3 and connect it to my backend, but the way it's done in this writeup is easier to follow. First we need to load our model and define an endpoint for the application. The endpoint will accept HTTP POST requests in order to receive an image uploaded by the user, store the file in a local directory, and feed it to the neural network. As a result, the client will receive a small JSON response with the probability of the predicted number and the proper label describing it.
First, we'll define the required dependencies for both SparkJava and deeplearning4j:
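A pom.xml fragment along these lines should do; the coordinates are the usual SparkJava and dl4j artifacts, but the version numbers here are illustrative, so check Maven Central for the current releases:

```xml
<dependencies>
  <dependency>
    <groupId>com.sparkjava</groupId>
    <artifactId>spark-core</artifactId>
    <version>2.7.2</version>
  </dependency>
  <dependency>
    <groupId>org.deeplearning4j</groupId>
    <artifactId>deeplearning4j-core</artifactId>
    <version>1.0.0-beta</version>
  </dependency>
  <!-- native backend for nd4j; picks the right binaries per platform -->
  <dependency>
    <groupId>org.nd4j</groupId>
    <artifactId>nd4j-native-platform</artifactId>
    <version>1.0.0-beta</version>
  </dependency>
</dependencies>
```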
Although this might seem like overkill, this is where the Java ecosystem, with its Maven repositories, really shines. All dependencies have their own dependencies with clearly defined versions (called transitive dependencies), so anytime you'd like to run this application you'll be able to do so. In the repository you'll find a few more dependencies, as I prefer log4j2; the rest are adapters for different logging frameworks, along with the specific Java and dl4j versions. This setup guarantees all logging is controlled through log4j2.xml.
So our development environment is ready, and we can start writing the code responsible for our REST endpoint.
First we need to read our model into memory so it can be used for recognizing images. I achieved this by reading the file from the src/main/resources/model directory.
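A sketch of the loading step; the file name under src/main/resources/model is my assumption, so use whatever name you saved the model under:

```java
import java.io.InputStream;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.util.ModelSerializer;

// Read the model from the classpath (it lives under src/main/resources)
InputStream modelStream =
        Main.class.getResourceAsStream("/model/trained_mnist_model.zip");
MultiLayerNetwork model = ModelSerializer.restoreMultiLayerNetwork(modelStream);
```
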
Next we need to make sure that we have a temporary storage for uploaded file:
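A minimal way to do that with plain java.nio; the directory prefix is arbitrary:

```java
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;

public class UploadDir {
    // Create a scratch directory for uploaded images
    public static Path create() throws IOException {
        Path dir = Files.createTempDirectory("mnist-uploads");
        dir.toFile().deleteOnExit(); // best-effort cleanup on JVM exit
        return dir;
    }
}
```
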
And finally we can add the definition of the endpoint:
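A sketch of what such a SparkJava route can look like; the route name, the multipart part name, the `uploadDir` variable, and the JSON shape are all my assumptions, not the post's exact code:

```java
import static spark.Spark.post;
import java.io.InputStream;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.StandardCopyOption;
import javax.servlet.MultipartConfigElement;

// Enable multipart handling on the underlying Jetty request, save the
// uploaded file, and (eventually) feed it to the network
post("/recognize", (request, response) -> {
    request.raw().setAttribute("org.eclipse.jetty.multipart.config",
            new MultipartConfigElement(System.getProperty("java.io.tmpdir")));
    try (InputStream is = request.raw().getPart("image").getInputStream()) {
        Path target = uploadDir.resolve("upload-" + System.nanoTime() + ".png");
        Files.copy(is, target, StandardCopyOption.REPLACE_EXISTING);
        // ...read `target`, convert it, and run the network here...
    }
    response.type("application/json");
    return "{\"digit\": 0, \"probability\": 0.0}"; // placeholder response
});
```
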
Once we have the file, we can start processing it. I used a static ImageIO call to read the file into memory. It's then converted into an INDArray, which is fed to the network.
Believe it or not, it's that simple: you only need to convert the file into a format understood by the neural network and boom, you're all set!
NOTE: Remember that this is just the initial version of this endpoint; currently it handles only images from the MNIST dataset itself.
The conversion from BufferedImage to INDArray is very simple, but I want to share it as there are at least two things worth mentioning:
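A sketch of that conversion, assuming the two things in question are the resize to the 28x28 input the LeNet example expects and the normalization to the 0..1 range the network was trained on:

```java
import java.awt.Graphics2D;
import java.awt.image.BufferedImage;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

static INDArray toINDArray(BufferedImage source) {
    // 1) scale the upload down to the 28x28 grayscale input size
    BufferedImage scaled = new BufferedImage(28, 28, BufferedImage.TYPE_BYTE_GRAY);
    Graphics2D g = scaled.createGraphics();
    g.drawImage(source, 0, 0, 28, 28, null);
    g.dispose();
    // 2) copy pixels into a [1, 1, 28, 28] tensor, normalized to 0..1
    INDArray input = Nd4j.create(new int[]{1, 1, 28, 28});
    for (int y = 0; y < 28; y++) {
        for (int x = 0; x < 28; x++) {
            int gray = scaled.getRaster().getSample(x, y, 0);
            input.putScalar(new int[]{0, 0, y, x}, gray / 255.0);
        }
    }
    return input;
}
```
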
There's still one piece of the puzzle missing: how do we interpret the output of the NN so we can respond with proper JSON? Our NN output is an INDArray vector with 10 entries, which are the prediction probabilities. From there it's really simple: using the INDArray interface, I first get the maximum value of the predictions (using the maxNumber method), then find the corresponding index, which is the recognized digit, and send both as the response.
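The same idea, sketched on a plain double[] standing in for the 10-entry INDArray; the JSON shape is a hypothetical example, not the post's exact response:

```java
import java.util.Locale;

public class Prediction {
    // Index of the largest prediction = the recognized digit
    public static int argMax(double[] predictions) {
        int best = 0;
        for (int i = 1; i < predictions.length; i++) {
            if (predictions[i] > predictions[best]) {
                best = i;
            }
        }
        return best;
    }

    // Build the small JSON response described above
    public static String toJson(double[] predictions) {
        int digit = argMax(predictions);
        return String.format(Locale.ROOT,
                "{\"digit\": %d, \"probability\": %.3f}",
                digit, predictions[digit]);
    }
}
```
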
NOTE: Yes, I know this isn't how it should be handled, but I want this demo to be as short as possible.
And there it is: we have the initial version of our REST endpoint.
Next we need a client-side application for feeding our REST endpoint with images. The easiest way is to use an HTML canvas, which allows drawing in the browser. This makes it very easy for users to get started: just go to the URL, draw your digit, and receive the result.
So let's start with the HTML canvas. It's an element introduced in HTML5 that probably everyone has heard of, but for real drawing to happen we need to use the 2D drawing API, and that can be a little tricky to write from scratch, so I decided to use Sketchpad. It looked pretty lightweight and had all the functionality I need. In the constructor you pass the width and height of your canvas and its query selector:
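Something along these lines; the exact option names are an assumption based on the description above and depend on the Sketchpad library version you use:

```javascript
// Initialize the drawing surface on an existing <canvas id="sketchpad">
const pad = new Sketchpad({
  element: '#sketchpad', // query selector of the canvas element
  width: 280,
  height: 280
});
```
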
After initializing the Sketchpad object, I decided to rely on the mouseup event, so the image is sent automatically to the server when the user stops drawing for at least 700 ms (achieved with a debounced function copied from a great article by Hajime Yamasaki Vukelic explaining the difference between debounce and throttle; I really recommend reading it if you're not sure what they are). Unfortunately, I had to use some boilerplate code to convert the canvas content into a Blob that can then be fed into the fetch method for posting.
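A minimal debounce in the spirit of that article (this is the general technique, not the article's exact code): the wrapped function runs only after `delay` milliseconds have passed without a new call.

```javascript
function debounce(fn, delay) {
  let timer = null;
  return function (...args) {
    // Every new call cancels the pending one and restarts the timer
    clearTimeout(timer);
    timer = setTimeout(() => fn.apply(this, args), delay);
  };
}
```

In the browser you would then wire it up roughly as `canvas.addEventListener('mouseup', debounce(sendDrawing, 700))`, where `sendDrawing` is a hypothetical function doing the `canvas.toBlob` plus `fetch` boilerplate mentioned above.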
But it all works great; you can try it yourself. There's just one small issue: my digits are drawn in black on a white background, which is the opposite of the MNIST dataset images. We deal with that in the next section.
With all those things in place, we can already connect our client to our backend, but there's still one thing left: the MNIST dataset contains black images with white digits, while the canvas demo produces a white background with black digits. So we need to do some data normalization so the input is understood by our network. Getting the negative of an image is very simple: just subtract the current pixel value from 255 and there you go :)
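Sketched with plain AWT, assuming a grayscale source image (the class and method names are mine):

```java
import java.awt.image.BufferedImage;

public class Negative {
    // Turn black-on-white canvas output into MNIST-style white-on-black:
    // every pixel becomes 255 minus its old value
    public static BufferedImage invert(BufferedImage image) {
        BufferedImage out = new BufferedImage(
                image.getWidth(), image.getHeight(), BufferedImage.TYPE_BYTE_GRAY);
        for (int y = 0; y < image.getHeight(); y++) {
            for (int x = 0; x < image.getWidth(); x++) {
                int gray = image.getRaster().getSample(x, y, 0);
                out.getRaster().setSample(x, y, 0, 255 - gray);
            }
        }
        return out;
    }
}
```
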
You may be wondering why I used AWT and BufferedImage and not some more sophisticated mechanism. My idea here is to keep as much processing as possible on the Java side, because maintaining different technologies in the same project is usually tough. Yes, I know that image operations could probably be delegated to libraries like JavaCV, but the trouble of learning a new API only for the sake of simple image manipulation seems like overkill to me.
That's it! As you can see, the real benefit of having a Java-based DL library is that it's really easy to integrate with whatever Java library or framework you'd like, leaving us in our safe, statically-typed world...
Source code is available on GitHub.