Scale TensorFlow 2 Applications

In this guide, we describe how to scale out TensorFlow 2 programs using Orca in 4 simple steps.

Step 0: Prepare Environment

We recommend using conda to prepare the environment. Please refer to the install guide for more details.

conda create -n py37 python=3.7  # "py37" is the conda environment name; you can use any name you like.
conda activate py37

pip install bigdl-orca[ray]
pip install tensorflow
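After installing, a quick sanity check can confirm that the packages are importable before moving on. The sketch below is not part of Orca: the helper name `check_packages` is hypothetical, and it relies only on the standard library's `importlib.util.find_spec`, so it looks modules up without importing TensorFlow itself (for a dotted name such as `bigdl.orca`, the parent package `bigdl` does get imported).

```python
import importlib.util


def check_packages(names):
    """Return {module_name: bool} telling whether each module can be found.

    Hypothetical helper: it only locates modules; for dotted names,
    find_spec imports the parent package first.
    """
    status = {}
    for name in names:
        try:
            status[name] = importlib.util.find_spec(name) is not None
        except ModuleNotFoundError:
            # Raised for dotted names whose parent package is missing.
            status[name] = False
    return status


if __name__ == "__main__":
    for module, found in check_packages(["tensorflow", "bigdl.orca"]).items():
        print(f"{module}: {'OK' if found else 'NOT FOUND'}")
```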

Step 1: Init Orca Context

from bigdl.orca import init_orca_context, stop_orca_context

cluster_mode = "local"  # set to "local", "k8s" or "yarn"

if cluster_mode == "local":  # For local machine
    init_orca_context(cluster_mode="local", cores=4, memory="4g")
elif cluster_mode == "k8s":  # For K8s cluster
    init_orca_context(cluster_mode="k8s", num_nodes=2, cores=2, memory="4g")
elif cluster_mode == "yarn":  # For Hadoop/YARN cluster
    init_orca_context(cluster_mode="yarn", num_nodes=2, cores=2, memory="4g")

This is the only place where you need to specify local or distributed mode. View Orca Context for more details.

Please check the tutorials if you want to run on Kubernetes or Hadoop/YARN clusters.
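Because `cluster_mode` is the only switch, one convenient option is to read it from the command line so that the same script runs unchanged on every target. This is a sketch, not an Orca API: the `--cluster_mode` flag and the helper names are hypothetical, and the resource values simply mirror the `init_orca_context` calls in the example above.

```python
import argparse

# Resource settings copied from the init_orca_context calls in Step 1.
ORCA_CONTEXT_KWARGS = {
    "local": {"cluster_mode": "local", "cores": 4, "memory": "4g"},
    "k8s": {"cluster_mode": "k8s", "num_nodes": 2, "cores": 2, "memory": "4g"},
    "yarn": {"cluster_mode": "yarn", "num_nodes": 2, "cores": 2, "memory": "4g"},
}


def parse_cluster_mode(argv=None):
    """Read a hypothetical --cluster_mode flag; defaults to "local"."""
    parser = argparse.ArgumentParser(description="Train LeNet with Orca")
    parser.add_argument("--cluster_mode", choices=sorted(ORCA_CONTEXT_KWARGS),
                        default="local")
    return parser.parse_args(argv).cluster_mode


def orca_context_kwargs(cluster_mode):
    """Return a copy of the keyword arguments to pass to init_orca_context."""
    return dict(ORCA_CONTEXT_KWARGS[cluster_mode])


# Usage in a real script (requires bigdl-orca):
# from bigdl.orca import init_orca_context
# init_orca_context(**orca_context_kwargs(parse_cluster_mode()))
```

Keeping the settings in a dictionary means a new target only needs a new entry, and `argparse` rejects misspelled modes before any cluster resources are requested.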

Step 2: Define the Model

You can then define the Keras model in a creator function, which takes a config dictionary and returns a compiled tf.keras model, using the standard TensorFlow 2 Keras APIs.

import tensorflow as tf

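def model_creator(config):
    # Build a LeNet-style CNN for 28x28x1 MNIST images
    model = tf.keras.Sequential(
        [tf.keras.layers.Conv2D(20, kernel_size=(5, 5), strides=(1, 1), activation='tanh',
                                input_shape=(28, 28, 1), padding='valid'),
         tf.keras.layers.MaxPooling2D(pool_size=(2, 2), strides=(2, 2), padding='valid'),
         tf.keras.layers.Conv2D(50, kernel_size=(5, 5), strides=(1, 1), activation='tanh',
                                padding='valid'),
         tf.keras.layers.MaxPooling2D(pool_size=(2, 2), strides=(2, 2), padding='valid'),
         tf.keras.layers.Flatten(),
         tf.keras.layers.Dense(500, activation='tanh'),
         tf.keras.layers.Dense(10, activation='softmax'),
         ]
    )

    # The Estimator expects a compiled model; MNIST labels are integer class ids
    model.compile(optimizer=tf.keras.optimizers.RMSprop(),
                  loss='sparse_categorical_crossentropy',
                  metrics=['accuracy'])
    return model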
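To see why a `Flatten` layer has to sit between the last pooling layer and the first `Dense` layer, and how large this LeNet-style model is, the sketch below redoes the shape and parameter arithmetic by hand for 'valid' padding. It is plain Python with a hypothetical helper name (`lenet_summary`); for this layer stack, Keras' `model.summary()` should report the same totals.

```python
def conv_out(size, kernel, stride=1):
    """Output width/height of a 'valid' convolution or pooling window."""
    return (size - kernel) // stride + 1


def lenet_summary(input_size=28, in_channels=1):
    """(name, output shape, trainable params) for each layer of the stack."""
    rows = []
    size, channels = input_size, in_channels

    # Conv2D(20, 5x5): 5*5*in_channels*20 weights plus 20 biases
    size = conv_out(size, 5)
    rows.append(("conv1", (size, size, 20), 5 * 5 * channels * 20 + 20))
    channels = 20
    size = conv_out(size, 2, 2)  # MaxPooling2D(2x2, stride 2) has no parameters
    rows.append(("pool1", (size, size, channels), 0))

    # Conv2D(50, 5x5) on 20 input channels
    size = conv_out(size, 5)
    rows.append(("conv2", (size, size, 50), 5 * 5 * channels * 50 + 50))
    channels = 50
    size = conv_out(size, 2, 2)
    rows.append(("pool2", (size, size, channels), 0))

    flat = size * size * channels  # Flatten turns (4, 4, 50) into 800 features
    rows.append(("flatten", (flat,), 0))
    rows.append(("dense1", (500,), flat * 500 + 500))
    rows.append(("dense2", (10,), 500 * 10 + 10))
    return rows


for name, shape, params in lenet_summary():
    print(f"{name:8s} {str(shape):14s} {params:>7d}")
print("total", sum(p for _, _, p in lenet_summary()))  # prints: total 431080
```

Almost all of the parameters (400,500 of 431,080) live in the first `Dense` layer, which is typical for this architecture.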

Step 3: Define the Dataset

You can define the dataset in a creator function using the standard tf.data APIs; the creator receives config and batch_size and returns a tf.data.Dataset. Orca also supports Spark DataFrame and Orca XShards as input.

def preprocess(x, y):
    x = tf.cast(tf.reshape(x, (28, 28, 1)), dtype=tf.float32) / 255.0
    return x, y

def train_data_creator(config, batch_size):
    (train_feature, train_label), _ = tf.keras.datasets.mnist.load_data()

    dataset = tf.data.Dataset.from_tensor_slices((train_feature, train_label))
    dataset = dataset.repeat()
    dataset = dataset.map(preprocess)
    dataset = dataset.shuffle(1000)
    dataset = dataset.batch(batch_size)
    return dataset

def val_data_creator(config, batch_size):
    _, (val_feature, val_label) = tf.keras.datasets.mnist.load_data()

    dataset = tf.data.Dataset.from_tensor_slices((val_feature, val_label))
    dataset = dataset.repeat()
    dataset = dataset.map(preprocess)
    dataset = dataset.batch(batch_size)
    return dataset
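The order of the tf.data calls in train_data_creator matters: `repeat()` turns the finite MNIST arrays into an endless stream, `map` applies `preprocess`, `shuffle(1000)` samples elements at random from a 1,000-element buffer, and `batch` groups consecutive elements. The pure-Python analogue below only illustrates those semantics on small integer ranges; it is not how tf.data is implemented, and the generator names are hypothetical.

```python
import itertools
import random


def repeat(items):
    """Like Dataset.repeat(): cycle over a re-iterable sequence forever."""
    if not items:
        return  # an empty input would otherwise loop forever without yielding
    while True:
        yield from items


def shuffle(stream, buffer_size, seed=0):
    """Like Dataset.shuffle(buffer_size): sample randomly from a bounded buffer."""
    rng = random.Random(seed)
    buffer = []
    for item in stream:
        buffer.append(item)
        if len(buffer) == buffer_size:
            # Swap a random element to the end and emit it; the next input refills the slot.
            idx = rng.randrange(buffer_size)
            buffer[idx], buffer[-1] = buffer[-1], buffer[idx]
            yield buffer.pop()
    # A finite input leaves some elements in the buffer; emit them in random order.
    rng.shuffle(buffer)
    yield from buffer


def batch(stream, batch_size):
    """Like Dataset.batch(batch_size): group consecutive elements into lists."""
    it = iter(stream)
    while True:
        chunk = list(itertools.islice(it, batch_size))
        if not chunk:
            return
        yield chunk


# repeat -> map -> shuffle -> batch, as in train_data_creator. The stream never
# ends, so the caller decides how many batches make an epoch (steps_per_epoch).
pipeline = batch(shuffle(map(lambda x: x / 10, repeat(range(10))), buffer_size=4), batch_size=3)
print(list(itertools.islice(pipeline, 2)))
```

Because the repeated stream has no end, a consumer that did not stop after a fixed number of batches would never finish, which is why the fit call in Step 4 passes `steps_per_epoch` and `validation_steps`.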

Step 4: Fit with Orca Estimator

First, create an Orca Estimator for TensorFlow 2.

from bigdl.orca.learn.tf2 import Estimator

est = Estimator.from_keras(model_creator=model_creator, workers_per_node=2)

Next, fit and evaluate using the Estimator.

batch_size = 320
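stats = est.fit(train_data_creator,
                epochs=5,
                batch_size=batch_size,
                steps_per_epoch=60000 // batch_size,
                validation_data=val_data_creator,
                validation_steps=10000 // batch_size)

stats = est.evaluate(val_data_creator,
                     batch_size=batch_size,
                     num_steps=10000 // batch_size)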
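Because both data creators call `repeat()`, the datasets never end on their own, so `steps_per_epoch` and `validation_steps` are what bound each pass. The arithmetic for MNIST (60,000 training and 10,000 test images) with `batch_size = 320` is sketched below using a hypothetical helper; it shows how many samples each bounded pass actually consumes.

```python
def steps_and_leftover(num_samples, batch_size):
    """Number of full batches in one pass, and samples that do not fill a last batch."""
    steps = num_samples // batch_size
    leftover = num_samples - steps * batch_size
    return steps, leftover


batch_size = 320
for split, n in [("train", 60000), ("test", 10000)]:
    steps, leftover = steps_and_leftover(n, batch_size)
    print(f"{split}: {steps} steps x {batch_size} = {steps * batch_size} samples, {leftover} left over")
# train: 187 steps x 320 = 59840 samples, 160 left over
# test: 31 steps x 320 = 9920 samples, 80 left over
```

Integer division keeps every step a full batch, at the cost of each pass covering slightly fewer samples than the whole split.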


Step 5: Save and Load the Model

Orca TensorFlow 2 Estimator supports two formats to save and load the entire model: TensorFlow SavedModel and Keras H5. The recommended format is SavedModel, which is the default format when you use est.save().

You can also save the model in Keras H5 format by passing save_format='h5', or a filename that ends in .h5 or .keras, to est.save().
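The format rule described above can be written down as a small function. This is a hypothetical helper that only mirrors that documented rule, returning "tf" for SavedModel and "h5" for Keras H5 (the two values tf.keras uses for `save_format`); it does not call Orca, and it deliberately handles only those two explicit values.

```python
def resolve_save_format(filepath, save_format=None):
    """Return "h5" or "tf" (SavedModel) following the rule described above.

    Hypothetical helper for illustration: an explicit save_format wins,
    otherwise a .h5 or .keras suffix selects H5 and anything else SavedModel.
    """
    if save_format is not None:
        if save_format not in ("tf", "h5"):
            raise ValueError(f"unsupported save_format: {save_format!r}")
        return save_format
    if filepath.endswith((".h5", ".keras")):
        return "h5"
    return "tf"  # SavedModel is the default


print(resolve_save_format("lenet_model"))                     # tf
print(resolve_save_format("lenet_model.h5"))                  # h5
print(resolve_save_format("lenet_model", save_format="h5"))   # h5
```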

Note that if you run on an Apache Hadoop/YARN cluster, we recommend saving the model to HDFS and loading it from HDFS as well.

1. SavedModel Format

# save model in SavedModel format
est.save("lenet_model")

# load model
est.load("lenet_model")

2. Keras H5 Format

# save model in H5 format
est.save("lenet_model.h5", save_format='h5')

# load model
est.load("lenet_model.h5")

That’s it: the same code runs seamlessly on your local laptop and scales out to Kubernetes or Hadoop/YARN clusters.

Note: You should call stop_orca_context() when your program finishes.
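One way to make sure stop_orca_context() runs even when training raises an exception is a try/finally block; the sketch below packages that pattern into a context manager. The helper name `orca_session` is hypothetical, and it takes the init and stop functions as arguments so the pattern can be shown and tested without importing BigDL.

```python
import contextlib


@contextlib.contextmanager
def orca_session(init_fn, stop_fn, **init_kwargs):
    """Call init_fn(**init_kwargs) on entry and stop_fn() on exit, even on errors.

    Hypothetical helper: in a real program pass bigdl.orca's init_orca_context
    and stop_orca_context as init_fn and stop_fn.
    """
    ctx = init_fn(**init_kwargs)
    try:
        yield ctx
    finally:
        stop_fn()


# Usage in a real script (requires bigdl-orca):
# from bigdl.orca import init_orca_context, stop_orca_context
# with orca_session(init_orca_context, stop_orca_context,
#                   cluster_mode="local", cores=4, memory="4g"):
#     ...  # create the Estimator, then fit / evaluate / save
```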