TensorFlow 2 Quickstart




This guide describes how to scale out TensorFlow 2 programs using Orca in 4 simple steps. (TensorFlow 1.15 and Keras 2.3 guides are also available.)

Step 0: Prepare Environment

We recommend using conda to prepare the environment. Please refer to the install guide for more details.

conda create -n py37 python=3.7  # "py37" is the conda environment name; you can use any name you like.
conda activate py37
pip install bigdl-orca[ray]
pip install tensorflow

Step 1: Init Orca Context

from bigdl.orca import init_orca_context, stop_orca_context

# Choose where to run: "local", "k8s", or "yarn".
cluster_mode = "local"

if cluster_mode == "local":  # For local machine
    init_orca_context(cluster_mode="local", cores=4, memory="10g")
elif cluster_mode == "k8s":  # For K8s cluster
    init_orca_context(cluster_mode="k8s", num_nodes=2, cores=2, memory="10g", driver_memory="10g", driver_cores=1)
elif cluster_mode == "yarn":  # For Hadoop/YARN cluster
    init_orca_context(cluster_mode="yarn", num_nodes=2, cores=2, memory="10g", driver_memory="10g", driver_cores=1)

This is the only place where you need to specify local or distributed mode. View Orca Context for more details.

Note: You should export HADOOP_CONF_DIR=/path/to/hadoop/conf/dir when running on a Hadoop YARN cluster. View Hadoop User Guide for more details.
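
The cluster_mode value used above has to come from somewhere. One common option is a command-line flag, so that the same script can be launched locally or on a cluster without editing it. The following is a minimal sketch using only the standard library; the helper name parse_cluster_mode and the flag name --cluster_mode are assumptions made for illustration, not part of the Orca API.

```python
import argparse


def parse_cluster_mode(argv=None):
    """Return the cluster mode chosen on the command line.

    Hypothetical helper: pass argv=None to read sys.argv, or pass an
    explicit list of arguments (handy for testing).
    """
    parser = argparse.ArgumentParser(description="Orca TensorFlow 2 quickstart")
    parser.add_argument(
        "--cluster_mode",
        type=str,
        default="local",
        choices=["local", "k8s", "yarn"],
        help="Where to run: local, k8s, or yarn.",
    )
    args = parser.parse_args(argv)
    return args.cluster_mode


# No flag given: fall back to the local default.
print(parse_cluster_mode([]))
# Flag given explicitly.
print(parse_cluster_mode(["--cluster_mode", "yarn"]))
```

The returned string can then drive the if/elif chain that calls init_orca_context.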

Step 2: Define the Model

You can then define the Keras model in the Creator Function using the standard TensorFlow 2 APIs.

import tensorflow as tf

def model_creator(config):
    model = tf.keras.Sequential(
        [tf.keras.layers.Conv2D(20, kernel_size=(5, 5), strides=(1, 1), activation='tanh',
                                input_shape=(28, 28, 1), padding='valid'),
         tf.keras.layers.MaxPooling2D(pool_size=(2, 2), strides=(2, 2), padding='valid'),
         tf.keras.layers.Conv2D(50, kernel_size=(5, 5), strides=(1, 1), activation='tanh',
                                padding='valid'),
         tf.keras.layers.MaxPooling2D(pool_size=(2, 2), strides=(2, 2), padding='valid'),
         tf.keras.layers.Flatten(),
         tf.keras.layers.Dense(500, activation='tanh'),
         tf.keras.layers.Dense(10, activation='softmax'),
         ]
    )

    model.compile(optimizer=tf.keras.optimizers.RMSprop(),
                  loss='sparse_categorical_crossentropy',
                  metrics=['accuracy'])
    return model
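
To see how the layer sizes in this model fit together, the arithmetic below traces the spatial size of a 28x28 input through the convolution and pooling layers, all of which use padding='valid'. It is a plain-Python sketch that needs no TensorFlow; the helper name valid_output_size is introduced here for illustration only.

```python
def valid_output_size(input_size, kernel_size, stride):
    """Spatial output size of a conv or pooling layer with padding='valid'."""
    return (input_size - kernel_size) // stride + 1


size = 28                              # MNIST images are 28x28
size = valid_output_size(size, 5, 1)   # Conv2D(20, 5x5, stride 1) -> 24
size = valid_output_size(size, 2, 2)   # MaxPooling2D(2x2, stride 2) -> 12
size = valid_output_size(size, 5, 1)   # Conv2D(50, 5x5, stride 1) -> 8
size = valid_output_size(size, 2, 2)   # MaxPooling2D(2x2, stride 2) -> 4
flattened = size * size * 50           # Flatten over 50 channels -> 800

# Trainable parameters: weights plus one bias per output unit or filter.
conv1_params = 5 * 5 * 1 * 20 + 20
conv2_params = 5 * 5 * 20 * 50 + 50
dense1_params = flattened * 500 + 500
dense2_params = 500 * 10 + 10
total_params = conv1_params + conv2_params + dense1_params + dense2_params

print(flattened, total_params)
```

The totals can be compared against the output of model.summary() on the model returned by model_creator.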

Step 3: Define Train Dataset

You can define the dataset in the Creator Function using standard tf.data.Dataset APIs. Orca also supports Spark DataFrame and Orca XShards.

def preprocess(x, y):
    x = tf.cast(tf.reshape(x, (28, 28, 1)), dtype=tf.float32) / 255.0
    return x, y

def train_data_creator(config, batch_size):
    (train_feature, train_label), _ = tf.keras.datasets.mnist.load_data()

    dataset = tf.data.Dataset.from_tensor_slices((train_feature, train_label))
    dataset = dataset.repeat()
    dataset = dataset.map(preprocess)
    dataset = dataset.shuffle(1000)
    dataset = dataset.batch(batch_size)
    return dataset

def val_data_creator(config, batch_size):
    _, (val_feature, val_label) = tf.keras.datasets.mnist.load_data()

    dataset = tf.data.Dataset.from_tensor_slices((val_feature, val_label))
    dataset = dataset.repeat()
    dataset = dataset.map(preprocess)
    dataset = dataset.batch(batch_size)
    return dataset
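
The train_data_creator above calls dataset.shuffle(1000). In tf.data, shuffle fills a buffer with that many elements and then draws randomly from the buffer, replacing each drawn element with the next one from the input. The sketch below is a rough pure-Python model of that behaviour on a finite stream, not TensorFlow's actual implementation; the function name buffered_shuffle and its seed parameter are assumptions for illustration.

```python
import random


def buffered_shuffle(items, buffer_size, seed=0):
    """Yield items in a locally shuffled order using a fixed-size buffer."""
    if buffer_size < 1:
        raise ValueError("buffer_size must be at least 1")
    rng = random.Random(seed)
    buffer = []
    for item in items:
        if len(buffer) < buffer_size:
            # Still filling the buffer: nothing is emitted yet.
            buffer.append(item)
            continue
        # Buffer is full: emit a random element and replace it.
        idx = rng.randrange(len(buffer))
        yield buffer[idx]
        buffer[idx] = item
    # Input exhausted: drain what is left in random order.
    rng.shuffle(buffer)
    yield from buffer


shuffled = list(buffered_shuffle(range(100), buffer_size=10))
# Every element appears exactly once, but only nearby elements are mixed.
print(sorted(shuffled) == list(range(100)))
```

Because the first emitted element always comes from the first buffer_size inputs, a buffer much smaller than the dataset gives only local mixing; a larger buffer trades memory for a more thorough shuffle.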

Step 4: Fit with Orca Estimator

First, create an Estimator.

from bigdl.orca.learn.tf2 import Estimator

est = Estimator.from_keras(model_creator=model_creator, workers_per_node=2)

Next, fit and evaluate using the Estimator.

batch_size = 320
stats = est.fit(train_data_creator,
                epochs=5,
                batch_size=batch_size,
                steps_per_epoch=60000 // batch_size,
                validation_data=val_data_creator,
                validation_steps=10000 // batch_size)
                
est.save("/tmp/mnist_keras.ckpt")

stats = est.evaluate(val_data_creator, num_steps=10000 // batch_size)
est.shutdown()
print(stats)
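
The steps_per_epoch and validation_steps arguments above are derived from the MNIST split sizes (60000 training and 10000 test images) by integer division. The short sketch below works through that arithmetic; the helper name steps_for is hypothetical.

```python
def steps_for(num_samples, batch_size):
    """Return (number of full batches, samples left over) for one pass."""
    full_steps = num_samples // batch_size
    leftover = num_samples - full_steps * batch_size
    return full_steps, leftover


batch_size = 320
train_steps, train_leftover = steps_for(60000, batch_size)
val_steps, val_leftover = steps_for(10000, batch_size)

print(train_steps, train_leftover)  # 187 full batches, 160 samples left over
print(val_steps, val_leftover)      # 31 full batches, 80 samples left over
```

Since neither split divides evenly by 320, a single pass would end in a partial batch. Both creator functions call dataset.repeat(), so the stream never runs out and the explicit step counts decide where each epoch ends.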

That’s it. The same code can run seamlessly on your local laptop and scale out to a distributed K8s or Hadoop cluster.

Note: You should call stop_orca_context() when your program finishes.
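
One way to make sure the stop call runs even when training raises an exception is to wrap the init and stop calls in a context manager. The following is a minimal sketch under that assumption; orca_session is a hypothetical helper, not part of BigDL, and the stand-in functions exist only so the example runs without BigDL installed.

```python
from contextlib import contextmanager


@contextmanager
def orca_session(init_fn, stop_fn, **init_kwargs):
    """Call init_fn on entry and guarantee stop_fn on exit (hypothetical helper)."""
    init_fn(**init_kwargs)
    try:
        yield
    finally:
        # Runs on normal exit and when the body raises.
        stop_fn()


# Stand-ins for init_orca_context / stop_orca_context that record their calls.
calls = []


def fake_init(**kwargs):
    calls.append(("init", kwargs))


def fake_stop():
    calls.append(("stop", {}))


try:
    with orca_session(fake_init, fake_stop, cluster_mode="local"):
        raise RuntimeError("training failed")
except RuntimeError:
    pass

print(calls)
```

With BigDL installed, the same helper could be given init_orca_context and stop_orca_context along with the arguments from Step 1, with Steps 2 to 4 placed inside the with block.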