
Use BFloat16 Mixed Precision for TensorFlow Keras Training

⚠️ Warning

This feature is under rapid iteration; its usage may change later.

     The following example is adapted from

Brain Floating Point Format (BFloat16) is a custom 16-bit floating point format designed for machine learning. BFloat16 consists of 1 sign bit, 8 exponent bits, and 7 mantissa bits. Because it has the same number of exponent bits as FP32, BFloat16 covers the same dynamic range while requiring only half the memory.
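To make the bit layout concrete, the following sketch converts a Float32 value to BFloat16 by keeping its upper 16 bits (sign, 8 exponent bits, 7 mantissa bits) with round-to-nearest-even. It uses only the Python standard library. The helper names are hypothetical and are not part of BigDL-Nano or TensorFlow, and NaN inputs are not handled.

```python
import struct


def float32_to_bfloat16_bits(value: float) -> int:
    """Return the 16-bit BFloat16 pattern for a value (NaN is not handled)."""
    # Reinterpret the value as its 32-bit IEEE 754 single-precision pattern.
    bits = struct.unpack('>I', struct.pack('>f', value))[0]
    # Round to nearest even on the 16 low bits that BFloat16 drops.
    rounding_bias = 0x7FFF + ((bits >> 16) & 1)
    return ((bits + rounding_bias) >> 16) & 0xFFFF


def bfloat16_bits_to_float(bits16: int) -> float:
    """Expand a BFloat16 pattern to a Python float by zero-filling the low bits."""
    return struct.unpack('>f', struct.pack('>I', (bits16 & 0xFFFF) << 16))[0]


def round_trip(value: float) -> float:
    """Convert to BFloat16 and back to show what the format can represent."""
    return bfloat16_bits_to_float(float32_to_bfloat16_bits(value))


large = round_trip(3.0e38)  # near the top of the Float32 range
small = round_trip(1.001)   # only 7 mantissa bits, so precision is coarse
print(large, small)

try:
    struct.pack('>e', 3.0e38)  # 'e' is IEEE half precision (FP16)
    fp16_overflows = False
except OverflowError:
    fp16_overflows = True
print('FP16 overflows on 3.0e38:', fp16_overflows)
```

The large value stays finite because the exponent field is as wide as in Float32, while the small value loses its fractional detail because only 7 mantissa bits remain.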

BFloat16 mixed precision combines BFloat16 and FP32 during training, which could lead to increased performance and reduced memory usage. Compared to FP16 mixed precision, BFloat16 mixed precision has better numerical stability.

BigDL-Nano provides a TensorFlow patch, patch_tensorflow, integrated with multiple optimizations. You can call patch_tensorflow(precision='mixed_bfloat16') to use BFloat16 mixed precision for training.

📝 Note

Before starting your TensorFlow Keras application, it is highly recommended to run source bigdl-nano-init to set several environment variables based on your current hardware. Empirically, these variables bring a significant performance increase for most TensorFlow Keras training workloads.

⚠️ Warning

BigDL-Nano enables Intel’s oneDNN optimizations by default. oneDNN BFloat16 operations are supported only on platforms with the AVX512 instruction set.

On platforms without hardware acceleration for BFloat16, BFloat16 training performance could be poor.
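One way to check a Linux machine before training is to look at the CPU flags in /proc/cpuinfo. The sketch below is an assumption about what is worth checking, not a description of what BigDL-Nano or oneDNN test internally: it looks for the Linux flag names avx512f (AVX512 foundation), avx512_bf16, and amx_bf16 (native BFloat16 instructions). The function names are hypothetical.

```python
from pathlib import Path

# Linux CPU flag names: 'avx512f' is the AVX512 foundation set;
# 'avx512_bf16' and 'amx_bf16' indicate native BFloat16 instructions.
BF16_FLAGS = ('avx512f', 'avx512_bf16', 'amx_bf16')


def parse_cpu_flags(cpuinfo_text: str) -> set:
    """Collect the CPU feature flags listed in /proc/cpuinfo-style text."""
    flags = set()
    for line in cpuinfo_text.splitlines():
        key, _, value = line.partition(':')
        if key.strip() == 'flags':
            flags.update(value.split())
    return flags


def bf16_support(cpuinfo_text: str) -> dict:
    """Report which of the BFloat16-related flags are present."""
    flags = parse_cpu_flags(cpuinfo_text)
    return {name: name in flags for name in BF16_FLAGS}


cpuinfo_path = Path('/proc/cpuinfo')  # present on Linux only
if cpuinfo_path.exists():
    print(bf16_support(cpuinfo_path.read_text()))
else:
    print('No /proc/cpuinfo here; use a platform-specific tool instead.')
```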

Patch TensorFlow

To run BFloat16 mixed precision training, the first step (and in most cases the only one) is to import patch_tensorflow from BigDL-Nano and call it with precision set to 'mixed_bfloat16':

[ ]:
from import patch_tensorflow
patch_tensorflow(precision='mixed_bfloat16')


📝 Note

Patching TensorFlow with 'mixed_bfloat16' as the precision sets a global 'mixed_bfloat16' dtype policy, which becomes the default policy for every Keras layer created after the patch.

A layer with the 'mixed_bfloat16' dtype policy performs its computation in BFloat16 while keeping its variables in Float32.
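The split between compute dtype and variable dtype can be inspected with the standard Keras mixed precision API. The sketch below does not use BigDL-Nano; it builds a 'mixed_bfloat16' policy directly and attaches it to a single layer, on the assumption that this mirrors what each layer sees after patching.

```python
import tensorflow as tf

# Build the policy object directly; this does not change any global state.
policy = tf.keras.mixed_precision.Policy('mixed_bfloat16')
print('compute dtype: ', policy.compute_dtype)
print('variable dtype:', policy.variable_dtype)

# Pass the policy to one layer to see its effect on the output tensor.
layer = tf.keras.layers.Dense(4, dtype=policy)
out = layer(tf.ones((2, 3)))
print('output dtype:  ', out.dtype)
```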

Build Model

Let’s take the MNIST digits classification dataset as an example, and suppose that we would like to create a model that will be trained on it:

[ ]:
from tensorflow import keras
from tensorflow.keras import layers, Model

inputs = keras.Input(shape=(784,), name='digits')

dense1 = layers.Dense(units=64, activation='relu', name='dense_1')
x = dense1(inputs)
dense2 = layers.Dense(units=64, activation='relu', name='dense_2')
x = dense2(x)

# Note that we separate the Dense layer and the softmax layer
# and set 'float32' as the dtype policy here for the last softmax layer
x = layers.Dense(10, name='dense_logits')(x)
outputs = layers.Activation('softmax', dtype='float32', name='predictions')(x)
print(f'Output dtype: {outputs.dtype}')

model = Model(inputs=inputs, outputs=outputs)

📝 Note

The dtype policy 'float32' we set here will override the global 'mixed_bfloat16' policy for the last layer, aiming at a Float32 output tensor for the model.

It is recommended to give the last layer of the model the 'float32' dtype policy, so that numeric issues caused by dtype mismatch are avoided when the output tensor flows into the loss.

Train Model

When you train with Model.fit, there is nothing special you need to do for BFloat16 mixed precision training:

[ ]:
# create train/test data
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
x_train = x_train.reshape(60000, 784).astype('float32') / 255
x_test = x_test.reshape(10000, 784).astype('float32') / 255

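# train with fit (the compile and fit arguments below are illustrative)
model.compile(loss='sparse_categorical_crossentropy',
              optimizer=keras.optimizers.RMSprop(),
              metrics=['accuracy'])
model.fit(x_train, y_train, epochs=1)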

test_scores = model.evaluate(x_test, y_test, verbose=2)
print('Test loss:', test_scores[0])
print('Test accuracy:', test_scores[1])

Custom training loop

If you create a custom training loop, you should also wrap the train/test step functions with the @nano_bf16 decorator:

[ ]:
import tensorflow as tf
from import nano_bf16 # import the decorator

# create loss function, optimizer, and train/test datasets
optimizer = keras.optimizers.RMSprop()
loss_object = tf.keras.losses.SparseCategoricalCrossentropy()
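# the training batch size below is illustrative
train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train)).batch(64)
test_dataset = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(8192)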

# define train/test step
@nano_bf16 # apply the decorator to the train_step
def train_step(x, y):
  with tf.GradientTape() as tape:
    predictions = model(x)
    loss = loss_object(y, predictions)
  gradients = tape.gradient(loss, model.trainable_variables)
  optimizer.apply_gradients(zip(gradients, model.trainable_variables))
  return loss

@nano_bf16 # apply the decorator to the test_step
def test_step(x):
  return model(x, training=False)

# conduct the training
for epoch in range(10):
  epoch_loss_avg = tf.keras.metrics.Mean()
  test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='test_accuracy')
  for x, y in train_dataset:
    loss = train_step(x, y)
    epoch_loss_avg.update_state(loss)  # accumulate the loss reported for this epoch
  for x, y in test_dataset:
    predictions = test_step(x)
    test_accuracy.update_state(y, predictions)
  print('Epoch {}: loss={}, test accuracy={}'.format(epoch, epoch_loss_avg.result(), test_accuracy.result()))

📝 Note

If you do not set the 'float32' dtype policy for the last layer of the model, the model outputs a BFloat16 tensor. In that case @nano_bf16 can compensate and help avoid dtype mismatch errors: it casts the input tensors and NumPy ndarrays of the decorated train/test step to BFloat16.

If you encounter a dtype mismatch error, you could also try applying the @nano_bf16 decorator to other functions in the custom training loop.
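The kind of dtype mismatch mentioned here can be reproduced with plain TensorFlow. The sketch below is only an illustration of the problem and of a manual cast as one remedy; it is not how @nano_bf16 is implemented, and the tensor values are made up for the example. With default TensorFlow settings, mixing BFloat16 and Float32 operands is expected to raise an error.

```python
import tensorflow as tf

# A BFloat16 tensor standing in for a model output, and Float32 targets.
predictions = tf.cast(tf.constant([0.25, 0.75]), tf.bfloat16)
targets = tf.constant([0.0, 1.0])  # Float32 by default

try:
    _ = predictions - targets  # operands have different dtypes
    mismatch_raised = False
except (TypeError, ValueError, tf.errors.InvalidArgumentError) as err:
    mismatch_raised = True
    print('dtype mismatch:', type(err).__name__)

# Casting one side so both operands share a dtype removes the error.
diff = predictions - tf.cast(targets, predictions.dtype)
print(diff.dtype)
```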

(Optional) Unpatch TensorFlow

If you want to go back to Float32 training again, you could simply call the unpatch_tensorflow function:

[ ]:
from import unpatch_tensorflow
unpatch_tensorflow()


# inspect one layer that was built before unpatching
print(f"model's dtype policy is still: {model.get_layer('dense_1').dtype_policy.name}")

📝 Note

A model created after calling unpatch_tensorflow will use 'float32' as its global dtype policy. However, a model created earlier, under patch_tensorflow(precision='mixed_bfloat16'), will still have layers with 'mixed_bfloat16' as their dtype policy.
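The same behavior can be observed with the standard Keras global policy API, because a layer captures the policy that is active when it is constructed. The sketch below uses set_global_policy as an assumed stand-in for patching and unpatching; BigDL-Nano's functions may do more than this.

```python
import tensorflow as tf

mixed_precision = tf.keras.mixed_precision
original_policy = mixed_precision.global_policy()

mixed_precision.set_global_policy('mixed_bfloat16')
layer_before = tf.keras.layers.Dense(4)  # created under 'mixed_bfloat16'

mixed_precision.set_global_policy('float32')
layer_after = tf.keras.layers.Dense(4)   # created under 'float32'

print(layer_before.dtype_policy.name)
print(layer_after.dtype_policy.name)

# Restore whatever policy was active before this sketch ran.
mixed_precision.set_global_policy(original_policy)
```

To train an existing architecture in Float32 again, the layers therefore need to be rebuilt after the policy change rather than reused.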
