Scale Keras 2.3 Applications


In this guide we will describe how to scale out Keras 2.3 programs using Orca in 4 simple steps.

Step 0: Prepare Environment

We recommend using conda to prepare the environment. Please refer to the install guide for more details.

conda create -n py37 python=3.7  # "py37" is the conda environment name; you can use any name you like.
conda activate py37
pip install bigdl-orca
pip install tensorflow==1.15.0
pip install tensorflow-datasets==2.1.0
pip install psutil
pip install pandas
pip install scikit-learn

Step 1: Init Orca Context

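from bigdl.orca import init_orca_context, stop_orca_context

# Assumed default so the snippet runs as-is; set to "local", "k8s", or "yarn".
cluster_mode = "local"
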
if cluster_mode == "local":  # For local machine
    init_orca_context(cluster_mode="local", cores=4, memory="10g")
    dataset_dir = "~/tensorflow_datasets"
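elif cluster_mode == "k8s":  # For K8s cluster
    init_orca_context(cluster_mode="k8s", num_nodes=2, cores=2, memory="10g", driver_memory="10g", driver_cores=1)
    # dataset_dir is not set for K8s in this guide; assign a path reachable from all nodes before Step 3.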
elif cluster_mode == "yarn":  # For Hadoop/YARN cluster
    init_orca_context(cluster_mode="yarn", num_nodes=2, cores=2, memory="10g", driver_memory="10g", driver_cores=1)
    dataset_dir = "hdfs:///tensorflow_datasets"

This is the only place where you need to specify local or distributed mode. View Orca Context for more details.
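
Since the cluster mode is the only thing that changes between environments, it can be convenient to read it from the command line. The sketch below is one possible way to do that with the standard library; the `--cluster_mode` flag name and both helper functions are hypothetical and not part of Orca, and the resource settings are simply copied from the snippet above.

```python
import argparse

# Resource settings copied from the snippet above; adjust them for your cluster.
ORCA_CONTEXT_KWARGS = {
    "local": dict(cluster_mode="local", cores=4, memory="10g"),
    "k8s": dict(cluster_mode="k8s", num_nodes=2, cores=2, memory="10g",
                driver_memory="10g", driver_cores=1),
    "yarn": dict(cluster_mode="yarn", num_nodes=2, cores=2, memory="10g",
                 driver_memory="10g", driver_cores=1),
}


def parse_cluster_mode(argv=None):
    """Read --cluster_mode (a hypothetical flag name) from the command line."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--cluster_mode", default="local",
                        choices=sorted(ORCA_CONTEXT_KWARGS))
    return parser.parse_args(argv).cluster_mode


def orca_context_kwargs(cluster_mode):
    """Return a copy of the keyword arguments to pass to init_orca_context."""
    return dict(ORCA_CONTEXT_KWARGS[cluster_mode])
```

With these helpers, the call would look like `init_orca_context(**orca_context_kwargs(parse_cluster_mode()))`, leaving the rest of the program unchanged.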

Note: You should export HADOOP_CONF_DIR=/path/to/hadoop/conf/dir when running on a Hadoop YARN cluster. View the Hadoop User Guide for more details. To use tensorflow_datasets on HDFS, you should also set HADOOP_HOME, HADOOP_HDFS_HOME, LD_LIBRARY_PATH, etc. correctly. For more details, please refer to the TensorFlow documentation.
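
A missing environment variable tends to surface only as an obscure failure later on, so it may help to check for the variables up front. The following is a minimal sketch using only the standard library; the helper names are hypothetical, and because the note above ends with "etc.", the variable list may be incomplete for your Hadoop installation.

```python
import os

# Variables named in the note above; the list may be incomplete ("etc.").
YARN_ENV_VARS = ("HADOOP_CONF_DIR",)
HDFS_DATASET_ENV_VARS = ("HADOOP_HOME", "HADOOP_HDFS_HOME", "LD_LIBRARY_PATH")


def missing_env_vars(names, environ=None):
    """Return the names that are unset or empty in the given environment."""
    environ = os.environ if environ is None else environ
    return [name for name in names if not environ.get(name)]


def check_yarn_environment(environ=None):
    """Raise a descriptive error if any expected variable is missing."""
    missing = missing_env_vars(YARN_ENV_VARS + HDFS_DATASET_ENV_VARS, environ)
    if missing:
        raise RuntimeError("Set these variables before using YARN/HDFS: "
                           + ", ".join(missing))
```

Calling `check_yarn_environment()` before `init_orca_context(cluster_mode="yarn", ...)` would then fail fast with the names of the missing variables.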

Step 2: Define the Model

You may define your model, loss and metrics in the same way as in any standard (single node) Keras program.

from tensorflow import keras

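model = keras.Sequential(
    [keras.layers.Conv2D(20, kernel_size=(5, 5), strides=(1, 1), activation='tanh',
                         input_shape=(28, 28, 1), padding='valid'),
     keras.layers.MaxPooling2D(pool_size=(2, 2), strides=(2, 2), padding='valid'),
     keras.layers.Conv2D(50, kernel_size=(5, 5), strides=(1, 1), activation='tanh',
                         padding='valid'),
     keras.layers.MaxPooling2D(pool_size=(2, 2), strides=(2, 2), padding='valid'),
     keras.layers.Flatten(),  # Dense layers need a flat feature vector
     keras.layers.Dense(500, activation='tanh'),
     keras.layers.Dense(10, activation='softmax'),
     ])

# RMSprop is an illustrative optimizer choice; the sparse loss matches integer labels.
model.compile(optimizer=keras.optimizers.RMSprop(),
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])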
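
To see how the spatial dimensions shrink through the network, the short sketch below walks a 28x28 input through the convolution and pooling layers listed in the model. It assumes every such layer uses 'valid' padding, as in the layers shown, and relies on the standard output-size formula for unpadded convolutions; the helper name is hypothetical.

```python
def valid_output_size(size, kernel, stride):
    """Spatial output size of a conv/pool layer with 'valid' (no) padding."""
    return (size - kernel) // stride + 1


size = 28  # MNIST images are 28x28x1
size = valid_output_size(size, kernel=5, stride=1)  # Conv2D(20, 5x5)  -> 24
size = valid_output_size(size, kernel=2, stride=2)  # MaxPooling2D 2x2 -> 12
size = valid_output_size(size, kernel=5, stride=1)  # Conv2D(50, 5x5)  -> 8
size = valid_output_size(size, kernel=2, stride=2)  # MaxPooling2D 2x2 -> 4

# 50 channels come out of the second Conv2D layer.
flattened_features = size * size * 50
print(size, flattened_features)  # prints: 4 800
```

Once flattened, the 4x4x50 feature map therefore gives the first Dense layer 800 input features.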


Step 3: Define Train Dataset

You can define the dataset using the standard tf.data.Dataset API. Orca also supports Spark DataFrame and Orca XShards.

import tensorflow as tf
import tensorflow_datasets as tfds

def preprocess(data):
    data['image'] = tf.cast(data["image"], tf.float32) / 255.
    return data['image'], data['label']

# get DataSet
mnist_train = tfds.load(name="mnist", split="train", data_dir=dataset_dir)
mnist_test = tfds.load(name="mnist", split="test", data_dir=dataset_dir)

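# Normalize the images and turn each element into an (image, label) pair
mnist_train = mnist_train.map(preprocess)
mnist_test = mnist_test.map(preprocess)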

Step 4: Fit with Orca Estimator

First, create an Estimator.

from bigdl.orca.learn.tf.estimator import Estimator

est = Estimator.from_keras(keras_model=model)

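Next, fit and evaluate using the Estimator.

est.fit(data=mnist_train, batch_size=320, epochs=1, validation_data=mnist_test)  # batch_size and epochs are illustrative values
result = est.evaluate(mnist_test)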

That’s it. The same code runs seamlessly on your local laptop and on a distributed K8s or Hadoop cluster.

Note: You should call stop_orca_context() when your program finishes.
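
One way to make sure the context is stopped even when training raises an exception is to wrap the init and stop calls in a context manager. The sketch below is not part of Orca: `orca_session` is a hypothetical helper that takes the init and stop functions as arguments, so it can be read and tried without a cluster.

```python
from contextlib import contextmanager


@contextmanager
def orca_session(init_fn, stop_fn, **init_kwargs):
    """Call init_fn(**init_kwargs) on entry and stop_fn() on exit, even on error."""
    context = init_fn(**init_kwargs)
    try:
        yield context
    finally:
        # Runs whether the body finished normally or raised.
        stop_fn()

# Intended usage with the functions imported in Step 1:
#
# with orca_session(init_orca_context, stop_orca_context,
#                   cluster_mode="local", cores=4, memory="10g"):
#     ...  # define the model, build the datasets, fit and evaluate
```

If `init_fn` itself fails, `stop_fn` is not called, since there is no context to stop at that point.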