Writing TensorFlow code that scales


Most of the time, we write and debug our code locally. Once it passes our tests, we deploy the scripts to a remote environment. If we're fortunate, we might have access to multiple GPUs there. This transition used to be a source of errors, but support for complex computational setups has improved significantly.

If you are mainly working with Keras' model.fit() call, you can quickly make your script distribution-aware by adding three lines of code.

Before distributing workloads

Initially, our code might look like this:
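A minimal sketch of such a script, assuming a small binary classifier trained on synthetic data (the layer sizes, data, and hyperparameters are placeholders, not taken from a real project):

```python
import numpy as np
import tensorflow as tf

# Synthetic stand-in data (an assumption made for this sketch).
features = np.random.rand(256, 20).astype("float32")
labels = np.random.randint(0, 2, size=(256, 1)).astype("float32")

# Model and optimizer are created the usual way, on the default device.
model = tf.keras.Sequential([
    tf.keras.Input(shape=(20,)),
    tf.keras.layers.Dense(32, activation="relu"),
    tf.keras.layers.Dense(1, activation="sigmoid"),
])
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)
model.compile(optimizer=optimizer, loss="binary_crossentropy", metrics=["accuracy"])

# Standard Keras training call.
history = model.fit(features, labels, batch_size=32, epochs=2, verbose=0)
```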

After distributing workloads

To make this code distribution-ready, we have to change the part that creates the model and the optimizer. For this, we create a strategy object, which is TensorFlow's way of handling computational environments. We then use it to wrap our code, as demonstrated in the following snippet:
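A minimal sketch of the distributed variant, again assuming a small placeholder classifier and synthetic data. The two additions are the strategy object and the `with strategy.scope():` block; on a machine without GPUs, MirroredStrategy is expected to fall back to a single CPU replica, so the script should still run:

```python
import numpy as np
import tensorflow as tf

# Synthetic stand-in data (an assumption made for this sketch).
features = np.random.rand(256, 20).astype("float32")
labels = np.random.randint(0, 2, size=(256, 1)).astype("float32")

# Added: the strategy object describing the computing environment.
strategy = tf.distribute.MirroredStrategy()

# Added: everything that creates variables goes inside the strategy's scope.
with strategy.scope():
    model = tf.keras.Sequential([
        tf.keras.Input(shape=(20,)),
        tf.keras.layers.Dense(32, activation="relu"),
        tf.keras.layers.Dense(1, activation="sigmoid"),
    ])
    optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)
    model.compile(optimizer=optimizer, loss="binary_crossentropy", metrics=["accuracy"])

# Unchanged: the training call stays exactly as before.
history = model.fit(features, labels, batch_size=32, epochs=2, verbose=0)
print("Replicas in sync:", strategy.num_replicas_in_sync)
```

Compiling inside the scope is a deliberate choice here, since metrics and optimizer state also create variables.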

In the first line that we add, we set up our strategy object. In the example, we use a MirroredStrategy, which tells TensorFlow to replicate the model on multiple computing devices on a single machine. For other environments, we can select different strategies, as listed in the documentation.

The remaining code is mostly the same, except that we wrap the model and optimizer creation routines within the scope of the chosen strategy object. This is the second line of code that we add, and it is the essential modification: the scope controls how variables are created and where they are placed.
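To see what the scope does, the following sketch checks which strategy is active outside and inside the `with` block and creates a single variable inside it. The variable name and shape are arbitrary choices for illustration:

```python
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()

# Outside any scope, only TensorFlow's default (no-op) strategy is active.
outside = tf.distribute.has_strategy()

with strategy.scope():
    # Inside the scope, our strategy is the active one ...
    inside = tf.distribute.has_strategy()
    active_is_ours = tf.distribute.get_strategy() is strategy
    # ... so variables created here are created and placed by the strategy.
    weight = tf.Variable(tf.zeros((4, 4)), name="weight")

print("Strategy active outside scope:", outside)
print("Strategy active inside scope:", inside)
print("Variable type created in scope:", type(weight).__name__)
```

Printing the variable's type is a quick way to confirm that the strategy, rather than plain TensorFlow, created it.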

Afterwards, we can proceed as usual and call model.fit(). The workload will then be distributed automatically. There is nothing left for us to do here, as TensorFlow handles everything internally, which is very convenient. It is also a huge improvement: only a couple of years ago, researchers and practitioners had to write their distribution algorithms themselves.

Now, having to instantiate the strategy object manually all the time is tiresome. However, we can automatically create the correct object with the following code:
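The sketch below approximates such a helper. It is a simplified version written for illustration, not verbatim library code: the function name `get_strategy` and its arguments `fp16` and `try_tpu` are hypothetical, and the TPU detection assumes that `TPUClusterResolver()` raises a `ValueError` when no TPU can be found.

```python
import tensorflow as tf


def get_strategy(fp16: bool = False, try_tpu: bool = True) -> tf.distribute.Strategy:
    """Pick a distribution strategy based on the available hardware.

    Hypothetical helper; the name and arguments are assumptions of this sketch.
    """
    tpu = None
    if try_tpu:
        try:
            # Assumed to raise ValueError when no TPU can be located.
            tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
        except ValueError:
            tpu = None

    if tpu is not None:
        if fp16:
            # TPUs use bfloat16 for mixed-precision training.
            tf.keras.mixed_precision.set_global_policy("mixed_bfloat16")
        tf.config.experimental_connect_to_cluster(tpu)
        tf.tpu.experimental.initialize_tpu_system(tpu)
        return tf.distribute.TPUStrategy(tpu)

    if fp16:
        # GPUs (and CPUs) use float16 for mixed-precision training.
        tf.keras.mixed_precision.set_global_policy("mixed_float16")

    gpus = tf.config.list_physical_devices("GPU")
    if len(gpus) > 1:
        # Several GPUs on one machine: replicate the model on each of them.
        return tf.distribute.MirroredStrategy()
    if len(gpus) == 1:
        return tf.distribute.OneDeviceStrategy(device="/gpu:0")
    # No accelerator found: fall back to the CPU.
    return tf.distribute.OneDeviceStrategy(device="/cpu:0")
```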

This code is adapted from Hugging Face's repository and checks the computing environment before creating the distribution strategy. Depending on the setup, the returned strategy object handles TPUs, GPUs, and also mixed-precision training.

If you use this method, then all you have to change is this line:
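A sketch of that change follows. To keep the example self-contained, it uses a condensed, hypothetical `get_strategy()` that only counts GPUs; the model is a placeholder.

```python
import tensorflow as tf


def get_strategy() -> tf.distribute.Strategy:
    """Condensed, hypothetical stand-in for an environment-checking helper."""
    gpus = tf.config.list_physical_devices("GPU")
    if len(gpus) > 1:
        return tf.distribute.MirroredStrategy()
    device = "/gpu:0" if gpus else "/cpu:0"
    return tf.distribute.OneDeviceStrategy(device=device)


# Before: strategy = tf.distribute.MirroredStrategy()
strategy = get_strategy()  # the only line that changes

# Everything else stays as it was.
with strategy.scope():
    model = tf.keras.Sequential([
        tf.keras.Input(shape=(4,)),
        tf.keras.layers.Dense(1),
    ])
    model.compile(optimizer=tf.keras.optimizers.Adam(), loss="mse")
```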


That is all you have to modify in your code. In summary, you

  • Create a distribution strategy
  • Wrap any variable-creating routines in the strategy's scope
  • Call model.fit() as normal

These steps cover the case where you use Keras' high-level API for all computation. However, if you are working with custom training loops, there are more things to pay attention to. In that case, I've got you covered here.

Pascal Janetzky


Avid reader & computer scientist