What is artistic style transfer?

One of the most exciting recent developments in deep learning is artistic style transfer: the ability to create a new image, known as a pastiche, based on two input images, one representing the artistic style and the other representing the content.

Using this technique, we can generate beautiful new artworks in a range of styles.

This codelab will walk you through the process of using an artistic style transfer neural network in an Android app in just 9 lines of code. You can also use the techniques outlined in this codelab to implement any TensorFlow network you have already trained.

What are we going to be building?

In this codelab, you're going to take an existing Android app and add a TensorFlow model to generate stylised images using the device's camera. You'll build the following skills:

  • Using TensorFlow's Android Java & native libraries in your app
  • Importing a trained TensorFlow model in an Android app
  • Performing inference in an Android app
  • Accessing specific tensors in a TensorFlow graph

What you'll need

Get the Code

There are two ways to grab the source for this codelab: either download a ZIP file containing the code, or clone it from GitHub.

ZIP Download

Click the following button to download all the code for this codelab:

Download source code

Unpack the downloaded zip file. This will unpack a root folder (tensorflow-style-transfer-android-codelab-start), which contains the base app we'll work on in this codelab, including all of the app resources.

Check out from GitHub

Check the code out from GitHub:

git clone https://github.com/googlecodelabs/tensorflow-style-transfer-android

This will create a directory containing everything you need. From within it, you can use git checkout codelab-start and git checkout codelab-finish to switch between the start & end of the lab, respectively.

Load the code in Android Studio

Open Android Studio and select Import Project. In the file dialog, navigate to the android directory within the directory you downloaded or checked out in the previous step. For example, if you checked out the code into your home directory, you'll want to open $HOME/tensorflow-style-transfer-android/android.

If prompted, you should accept the suggestion to use the Gradle wrapper and decline to use Instant Run.

Once Android Studio has imported the project, use the file browser to open the StylizeActivity class. This is where we'll be working; if the file loads correctly, let's move on to the next section.

What's in this app?

This app skeleton contains an Android app that takes frames from the device's camera and renders them to a view on the main activity.

UI controls

  • The first button, labelled with a number (256 by default), controls the size of the image to display (and eventually run through the style transfer network). Smaller numbers mean smaller images, which are faster to transform but lower quality; bigger images contain more detail but take longer to transform.
  • The second button, labelled save, will save the current frame to your device for you to use later.
  • The thumbnails show the styles you can use to transform the camera feed. Each thumbnail is also a slider: you can combine multiple sliders, and their values represent the ratio of each style to apply to your camera frames. These ratios, along with the camera frame, are the inputs to the network.

What is all this extra code?

The app code includes some helpers that are required to interface between native TensorFlow and Android Java. The details of their implementation are not important, but you should understand what they do.

StylizeActivity.onPreviewSizeChosen(...)

The app skeleton uses a custom camera fragment that will call this method once permissions have been granted and the camera is available to use.

StylizeActivity.setStyle(...)

This keeps the style sliders normalised such that their values sum to 1.0, in line with what our network is expecting.
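
As a rough sketch (not the app's actual implementation), that normalisation could look something like the following, assuming the slider values live in a float array such as styleVals:

// Hypothetical sketch of slider normalisation; the app's real setStyle(...)
// does this for you. styleVals is assumed to be the float[NUM_STYLES]
// array that is later fed to the network's style input.
private void normaliseSliders(final float[] styleVals) {
  float sum = 0f;
  for (final float v : styleVals) {
    sum += v;
  }
  if (sum > 0f) {
    for (int i = 0; i < styleVals.length; ++i) {
      styleVals[i] /= sum;  // values now sum to 1.0
    }
  }
}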

StylizeActivity.renderDebug(...)

Provides a debug overlay when you press the volume up or down buttons on the device, including output from TensorFlow, performance metrics, and the original, unstyled image.

StylizeActivity.stylizeImage(...)

This is where we will do our work. The provided code performs some conversion between arrays of integers (provided by Android's getPixels() method) of the form [0xRRGGBB, ...] and arrays of floats in the range [0.0, 1.0] of the form [r, g, b, r, g, b, ...].
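
For illustration, the packed-integer to float conversion amounts to the following (a sketch of what the provided code already does, so you don't need to add it):

// Each int pixel is packed as 0xRRGGBB; each channel is extracted
// with shifts and masks, then scaled from [0, 255] to [0.0, 1.0].
for (int i = 0; i < intValues.length; ++i) {
  final int pixel = intValues[i];
  floatValues[i * 3] = ((pixel >> 16) & 0xFF) / 255.0f;     // red
  floatValues[i * 3 + 1] = ((pixel >> 8) & 0xFF) / 255.0f;  // green
  floatValues[i * 3 + 2] = (pixel & 0xFF) / 255.0f;         // blue
}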

ImageUtils.*

Provides some helpers for transforming images. The camera provides image data in YUV space (as it is the most widely supported format), but the network expects RGB, so these helpers convert between the two. Most of them are implemented in native C++ for speed; the code lives in the jni directory, but for this lab it is provided via the pre-built libtensorflow_demo.so binaries in the libs directory (defined as jniLibs in Android Studio). If these aren't available, the code will fall back to a Java implementation.
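
To give a feel for what these helpers do, here is an illustrative single-pixel YUV to RGB conversion using the common ITU-R BT.601 integer approximation; the app's real converters (native or Java fallback) are more elaborate and process whole frames:

// Sketch: convert one YUV pixel to a packed ARGB int (BT.601 coefficients).
static int yuvPixelToRgb(final int y, final int u, final int v) {
  final int yScaled = Math.max(y - 16, 0);
  final int uC = u - 128;
  final int vC = v - 128;
  int r = (int) (1.164f * yScaled + 1.596f * vC);
  int g = (int) (1.164f * yScaled - 0.813f * vC - 0.391f * uC);
  int b = (int) (1.164f * yScaled + 2.018f * uC);
  r = Math.min(Math.max(r, 0), 255);
  g = Math.min(Math.max(g, 0), 255);
  b = Math.min(Math.max(b, 0), 255);
  return 0xFF000000 | (r << 16) | (g << 8) | b;
}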

About this network

The network we are importing is the result of a number of important developments. The first neural style transfer paper (Gatys, et al. 2015) introduced a technique that exploits properties of convolutional image classification networks, where lower layers identify simple edges and shapes (components of style) and higher layers identify more complex content, to generate a pastiche. This technique works on any two images, but is slow to execute.

A number of improvements have since been proposed, including one that makes a trade-off by pre-training networks for each style (Johnson, et al. 2016), resulting in real-time image generation.

Finally, the authors of the network we use in this lab (Dumoulin, et al. 2016) reasoned that separate networks representing different styles would likely duplicate a lot of information, and proposed a single network trained on multiple styles. An interesting side-effect of this is the ability to combine styles, which is what we use here.

For a more technical comparison of these networks, as well as review of others, check out Cinjon Resnick's review article.

Inside the network

The original TensorFlow code that generated this network is available on Magenta's GitHub page, specifically the stylized image transformation model (README).

Before using it in an environment with constrained resources, such as a mobile app, this model was exported and transformed to use smaller data types & remove redundant calculations. You can read more about this process in the Graph Transforms doc, and try it out in the TensorFlow for Poets II: Optimize for Mobile codelab.

The end result is the stylize_quantized.pb file, displayed below, which you will use in the app. The transformer node contains most of the graph; click through to the interactive version to expand it.

Explore this graph interactively

Add dependencies to project

To add the inference libraries and their dependencies to our project, we need the TensorFlow Android Inference Library and Java API, which is available from JCenter; alternatively, you can build it from the TensorFlow source.

  1. Open build.gradle in Android Studio.
  2. Add the API to the project by placing it in the dependencies block within the android block (note: this is not the buildscript block).

build.gradle

dependencies {
   compile 'org.tensorflow:tensorflow-android:1.2.0-preview'
}
  3. Click the Gradle sync button to make these changes available in the IDE.

The TensorFlow Inference Interface

When running TensorFlow code, you would normally need to manage both a computational graph and a session (as covered in the Getting Started docs). However, since Android developers will usually want to perform inference over a prebuilt graph, TensorFlow provides a Java interface that manages the graph and session for you: TensorFlowInferenceInterface.

If you need more control, the TensorFlow Java API provides the familiar Session and Graph objects you may know from the Python API.
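
For example, a minimal sketch of driving a graph with those lower-level classes might look like this (the org.tensorflow classes ship in the same tensorflow-android artifact added above; the runOnce helper and its parameters are hypothetical, and the node names are the ones this codelab's graph uses):

import java.nio.FloatBuffer;
import org.tensorflow.Graph;
import org.tensorflow.Session;
import org.tensorflow.Tensor;

// Sketch only: feeds one float tensor and fetches one output by name.
// graphDef holds the serialised model bytes (e.g. stylize_quantized.pb).
static float[] runOnce(final byte[] graphDef, final float[] input,
                       final long[] shape, final int outputSize) {
  try (Graph graph = new Graph()) {
    graph.importGraphDef(graphDef);
    try (Session session = new Session(graph);
         Tensor inputTensor = Tensor.create(shape, FloatBuffer.wrap(input));
         Tensor output = session.runner()
             .feed("input", inputTensor)
             .fetch("transformer/expand/conv3/conv/Sigmoid")
             .run()
             .get(0)) {
      final float[] result = new float[outputSize];
      output.writeTo(FloatBuffer.wrap(result));
      return result;
    }
  }
}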

The Style Transfer Network

We have included the style transfer network described in the last section in the project's assets directory, so it is already available for you to use. You can also download it directly, or build it yourself from the Magenta project.

It may be worth opening the interactive graph viewer so you can see the nodes we will reference shortly (hint: open the transformer node by clicking the + icon that appears when you hover over it).

Explore the graph interactively

Add the inference code

  1. In StylizeActivity.java, add the following member fields near the top of the class (e.g. right before the NUM_STYLES declaration).

StylizeActivity.java

private TensorFlowInferenceInterface inferenceInterface;

private static final String MODEL_FILE = "file:///android_asset/stylize_quantized.pb";

private static final String INPUT_NODE = "input";
private static final String STYLE_NODE = "style_num";
private static final String OUTPUT_NODE = "transformer/expand/conv3/conv/Sigmoid";

private static final int NUM_STYLES = 26;
  2. In the same class, find the onPreviewSizeChosen method and construct the TensorFlowInferenceInterface. We use this method for initialisation because it is called once permissions for the file system & camera have been granted.

StylizeActivity.java

@Override
public void onPreviewSizeChosen(final Size size, final int rotation) {
 // anywhere in here is fine

 inferenceInterface = new TensorFlowInferenceInterface(getAssets(), MODEL_FILE);

 // anywhere at all...
}
  3. Now find the stylizeImage method and add the code below to pass our camera bitmap and chosen styles to TensorFlow and grab the output from the graph. This goes in between the two loops.

StylizeActivity.java

private void stylizeImage(final Bitmap bitmap) {
 // Find the code marked with: TODO: Process the image in TensorFlow here.
 // Then paste the following code in at that location.
 
 // Start copying here:

 // Copy the input data into TensorFlow.
 inferenceInterface.feed(INPUT_NODE, floatValues, 
   1, bitmap.getWidth(), bitmap.getHeight(), 3);
 inferenceInterface.feed(STYLE_NODE, styleVals, NUM_STYLES);

 // Execute the output node's dependency sub-graph.
 inferenceInterface.run(new String[] {OUTPUT_NODE}, isDebug());

 // Copy the data from TensorFlow back into our array.
 inferenceInterface.fetch(OUTPUT_NODE, floatValues);

 // Don't copy this code, it's already in there.
 for (int i = 0; i < intValues.length; ++i) {
   // ...
 }
}
  4. Optional: find renderDebug and add the TensorFlow status text to the debug overlay (triggered when you press the volume keys).

StylizeActivity.java

private void renderDebug(final Canvas canvas) {
 // ... provided code that does some drawing ...

 // Look for this line, but don't copy it, it's already there.
 final Vector<String> lines = new Vector<>();

 // Add these three lines right here:
 final String[] statLines = inferenceInterface.getStatString().split("\n");
 Collections.addAll(lines, statLines);
 lines.add("");

 // Don't add this line, it's already there
 lines.add("Frame: " + previewWidth + "x" + previewHeight);
 // ... more provided code for rendering the text ...
}
  5. In Android Studio, press the Run button and wait for the project to build.
  6. You should now see style transfer happening on your device!

Further Reading

Interesting Networks

Other Code Labs