#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import io
import types
import logging
import numbers
import torch
import numpy as np
from bigdl.orca.data.ray_xshards import RayXShards
from bigdl.orca.learn.pytorch.training_operator import TrainingOperator
from bigdl.orca.learn.pytorch.pytorch_ray_worker import PytorchRayWorker
from bigdl.orca.learn.utils import maybe_dataframe_to_xshards, dataframe_to_xshards, \
convert_predict_xshards_to_dataframe, update_predict_xshards, \
process_xshards_of_pandas_dataframe
from bigdl.orca.ray import RayContext
from bigdl.orca.learn.ray_estimator import Estimator as OrcaRayEstimator
from bigdl.dllib.utils.file_utils import enable_multi_fs_load, enable_multi_fs_save
import ray
from ray.exceptions import RayActorError
logger = logging.getLogger(__name__)
def check_for_failure(remote_values):
"""Checks remote values for any that returned and failed.
:param remote_values: List of object IDs representing functions
that may fail in the middle of execution. For example, running
a SGD training loop in multiple parallel actor calls.
:return Bool for success in executing given remote tasks.
"""
unfinished = remote_values
try:
while len(unfinished) > 0:
finished, unfinished = ray.wait(unfinished)
ray.get(finished)  # raises RayActorError if any finished task failed
return True
except RayActorError as exc:
logger.exception(str(exc))
return False
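# Usage sketch (hedged; `workers` below stands for a hypothetical list of
# Ray actor handles exposing a remote training method):
#     refs = [w.train_epochs.remote(**params) for w in workers]
#     if check_for_failure(refs):
#         stats = ray.get(refs)  # safe: every task finished successfully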
def partition_refs_to_creator(partition_refs):
def data_creator(config, batch_size):
from bigdl.orca.data.utils import ray_partitions_get_data_label, index_data, get_size
from torch.utils.data import Dataset, DataLoader
class NDArrayDataset(Dataset):
def __init__(self, x, y):
self.x = x # features
self.y = y # labels
def __len__(self):
return get_size(self.y)
def __getitem__(self, i):
return index_data(self.x, i), index_data(self.y, i)
params = {"batch_size": batch_size, "shuffle": True}
for arg in ["shuffle", "sampler", "batch_sampler", "num_workers", "collate_fn",
"pin_memory", "drop_last", "timeout", "worker_init_fn",
"multiprocessing_context"]:
if arg in config:
params[arg] = config[arg]
data, label = ray_partitions_get_data_label(ray.get(partition_refs),
allow_tuple=False,
allow_list=False)
print("Data size on worker: ", len(label))
dataset = NDArrayDataset(data, label)
data_loader = DataLoader(dataset, **params)
return data_loader
return data_creator
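# Sketch of how DataLoader options flow through `config` (hedged; `refs` is
# a hypothetical list of Ray object refs to data partitions):
#     creator = partition_refs_to_creator(refs)
#     loader = creator({"num_workers": 2, "drop_last": True}, batch_size=64)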
class PyTorchRayEstimator(OrcaRayEstimator):
def __init__(
self,
*,
model_creator,
optimizer_creator,
loss_creator=None,
metrics=None,
scheduler_creator=None,
training_operator_cls=TrainingOperator,
initialization_hook=None,
config=None,
scheduler_step_freq="batch",
use_tqdm=False,
backend="torch_distributed",
workers_per_node=1,
sync_stats=True,
log_level=logging.INFO):
if config is not None and "batch_size" in config:
raise Exception("Please do not specify batch_size in config. Pass batch_size to the"
" fit/evaluate/predict methods of the estimator instead.")
# TODO: remove ray_ctx to run on workers
ray_ctx = RayContext.get()
# Require plain functions rather than any callable: a torch model instance
# is itself callable and would slip through a callable() check.
if not (isinstance(model_creator, types.FunctionType) and
isinstance(optimizer_creator, types.FunctionType)):
raise ValueError(
"Must provide a function for both model_creator and optimizer_creator")
self.model_creator = model_creator
self.optimizer_creator = optimizer_creator
self.loss_creator = loss_creator
self.scheduler_creator = scheduler_creator
self.training_operator_cls = training_operator_cls
self.scheduler_step_freq = scheduler_step_freq
self.use_tqdm = use_tqdm
self.sync_stats = sync_stats
if not training_operator_cls and not loss_creator:
raise ValueError("If a loss_creator is not provided, you must "
"provide a custom training operator.")
self.initialization_hook = initialization_hook
self.config = {} if config is None else config
worker_config = self.config.copy()
params = dict(
model_creator=self.model_creator,
optimizer_creator=self.optimizer_creator,
loss_creator=self.loss_creator,
scheduler_creator=self.scheduler_creator,
training_operator_cls=self.training_operator_cls,
scheduler_step_freq=self.scheduler_step_freq,
use_tqdm=self.use_tqdm,
config=worker_config,
metrics=metrics,
sync_stats=sync_stats,
log_level=log_level
)
if backend == "torch_distributed":
cores_per_node = ray_ctx.ray_node_cpu_cores // workers_per_node
# num_nodes here is the total number of remote workers across the cluster.
num_nodes = ray_ctx.num_ray_nodes * workers_per_node
RemoteRunner = ray.remote(num_cpus=cores_per_node)(PytorchRayWorker)
self.remote_workers = [
RemoteRunner.remote(**params) for _ in range(num_nodes)
]
ray.get([
worker.setup.remote(cores_per_node)
for i, worker in enumerate(self.remote_workers)
])
head_worker = self.remote_workers[0]
address = ray.get(head_worker.setup_address.remote())
logger.info(f"initializing pytorch process group on {address}")
ray.get([
worker.setup_torch_distribute.remote(address, i, num_nodes)
for i, worker in enumerate(self.remote_workers)
])
elif backend == "horovod":
from bigdl.orca.learn.horovod.horovod_ray_runner import HorovodRayRunner
self.horovod_runner = HorovodRayRunner(ray_ctx,
worker_cls=PytorchRayWorker,
worker_param=params,
workers_per_node=workers_per_node)
self.remote_workers = self.horovod_runner.remote_workers
cores_per_node = self.horovod_runner.cores_per_node
ray.get([
worker.setup.remote(cores_per_node)
for i, worker in enumerate(self.remote_workers)
])
ray.get([
worker.setup_horovod.remote()
for i, worker in enumerate(self.remote_workers)
])
else:
raise Exception("Only \"torch_distributed\" and \"horovod\" are supported "
"values of backend, but got {}".format(backend))
self.num_workers = len(self.remote_workers)
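# A minimal construction sketch (hedged; `model_creator` and
# `optimizer_creator` are hypothetical user-supplied functions):
#     import torch.nn as nn
#     est = PyTorchRayEstimator(model_creator=model_creator,
#                               optimizer_creator=optimizer_creator,
#                               loss_creator=nn.CrossEntropyLoss,
#                               workers_per_node=2)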
def fit(self,
data,
epochs=1,
batch_size=32,
profile=False,
reduce_results=True,
info=None,
feature_cols=None,
label_cols=None,
callbacks=None):
"""
Trains a PyTorch model given training data for several epochs.
Calls `TrainingOperator.train_epoch()` on N parallel workers simultaneously
underneath the hood.
:param data: An instance of SparkXShards, a Spark DataFrame or a function that
takes config and batch_size as argument and returns a PyTorch DataLoader for
training.
:param epochs: The number of epochs to train the model. Default is 1.
:param batch_size: The number of samples per batch for each worker. Default is 32.
The total batch size would be batch_size * workers_per_node * num_nodes.
If your training data is a function, batch_size will be passed to that function
as the batch size for the PyTorch DataLoader it creates.
:param profile: Boolean. Whether to return time stats for the training procedure.
Default is False.
:param reduce_results: Boolean. Whether to average all metrics across all workers into
one dict. If a metric is a non-numerical value (or nested dictionaries), the value
from the first worker will be used. If False, returns a list of dicts for
all workers. Default is True.
:param info: An optional dictionary that can be passed to the TrainingOperator for
train_epoch and train_batch.
:param feature_cols: feature column names if data is a Spark DataFrame.
:param label_cols: label column names if data is a Spark DataFrame.
:param callbacks: A list of callbacks to run during training. Default is None.
:return: A list of dictionaries of metrics, one for every training epoch. If reduce_results
is False, this will return a nested list of metric dictionaries whose length will be
equal to the total number of workers.
You can also provide custom metrics by passing in a custom training_operator_cls
when creating the Estimator.
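Example (a minimal sketch; `estimator` and `train_loader_creator` below are
hypothetical, where the latter takes config and batch_size and returns a
PyTorch DataLoader)::
    stats = estimator.fit(train_loader_creator, epochs=2, batch_size=64)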
"""
callbacks = [] if callbacks is None else callbacks
from bigdl.orca.data import SparkXShards
data, _ = maybe_dataframe_to_xshards(data,
validation_data=None,
feature_cols=feature_cols,
label_cols=label_cols,
mode="fit",
num_workers=self.num_workers)
if isinstance(data, SparkXShards):
if data._get_class_name() == 'pandas.core.frame.DataFrame':
data = process_xshards_of_pandas_dataframe(data, feature_cols, label_cols)
from bigdl.orca.data.utils import process_spark_xshards
ray_xshards = process_spark_xshards(data, self.num_workers)
def transform_func(worker, partition_refs):
data_creator = partition_refs_to_creator(partition_refs)
# Should not wrap DistributedSampler on DataLoader for SparkXShards input.
return worker.train_epochs.remote(
data_creator, epochs, batch_size, profile, info, False, callbacks)
worker_stats = ray_xshards.reduce_partitions_for_actors(self.remote_workers,
transform_func)
else:
assert isinstance(data, types.FunctionType), \
"data should be either an instance of SparkXShards or a callable function, but " \
"got type: {}".format(type(data))
success, worker_stats = self._train_epochs(data,
epochs=epochs,
batch_size=batch_size,
profile=profile,
info=info,
callbacks=callbacks)
if not success:
raise RuntimeError("Training failed on one or more remote workers; "
"see the worker logs for details.")
# Transpose the per-worker list of epoch stats into a per-epoch list of worker stats.
epoch_stats = list(map(list, zip(*worker_stats)))
if reduce_results:
epoch_stats = [self._process_stats(stats) for stats in epoch_stats]
return epoch_stats
def predict(self,
data,
batch_size=32,
feature_cols=None,
profile=False):
"""
Uses this PyTorch model to make predictions on the data.
:param data: An instance of SparkXShards or a Spark DataFrame.
:param batch_size: The number of samples per batch for each worker. Default is 32.
:param feature_cols: feature column names if data is a Spark DataFrame.
:param profile: Boolean. Whether to return time stats for the prediction procedure.
Default is False.
:return: A SparkXShards that contains the predictions with key "prediction" in each shard
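Example (a minimal sketch; assumes `df` is a Spark DataFrame with a
"features" column)::
    predictions = estimator.predict(df, batch_size=64, feature_cols=["features"])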
"""
from bigdl.orca.data import SparkXShards
param = dict(
batch_size=batch_size,
profile=profile
)
from pyspark.sql import DataFrame
if isinstance(data, DataFrame):
xshards, _ = dataframe_to_xshards(data,
validation_data=None,
feature_cols=feature_cols,
label_cols=None,
mode="predict")
pred_shards = self._predict_spark_xshards(xshards, param)
result = convert_predict_xshards_to_dataframe(data, pred_shards)
elif isinstance(data, SparkXShards):
if data._get_class_name() == 'pandas.core.frame.DataFrame':
data = process_xshards_of_pandas_dataframe(data, feature_cols)
pred_shards = self._predict_spark_xshards(data, param)
result = update_predict_xshards(data, pred_shards)
else:
raise ValueError("Only xshards or Spark DataFrame is supported for predict")
return result
def evaluate(self,
data,
batch_size=32,
num_steps=None,
profile=False,
info=None,
feature_cols=None,
label_cols=None):
"""
Evaluates a PyTorch model given validation data.
Note that only accuracy for classification with zero-based labels is supported by
default. You can override validate_batch in TrainingOperator for other metrics.
Calls `TrainingOperator.validate()` on N parallel workers simultaneously
underneath the hood.
:param data: An instance of SparkXShards, a Spark DataFrame or a function that
takes config and batch_size as argument and returns a PyTorch DataLoader for
validation.
:param batch_size: The number of samples per batch for each worker. Default is 32.
The total batch size would be batch_size * workers_per_node * num_nodes.
If your validation data is a function, batch_size will be passed to that function
as the batch size for the PyTorch DataLoader it creates.
:param num_steps: The number of batches to compute the validation results on. This
corresponds to the number of times `TrainingOperator.validate_batch` is called.
:param profile: Boolean. Whether to return time stats for the evaluation procedure.
Default is False.
:param info: An optional dictionary that can be passed to the TrainingOperator
for validate.
:param feature_cols: feature column names if data is a Spark DataFrame.
:param label_cols: label column names if data is a Spark DataFrame.
:return: A dictionary of metrics for the given data, including validation accuracy and loss.
You can also provide custom metrics by passing in a custom training_operator_cls
when creating the Estimator.
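Example (a minimal sketch; `val_loader_creator` is a hypothetical user-defined
function taking config and batch_size and returning a PyTorch DataLoader)::
    metrics = estimator.evaluate(val_loader_creator, batch_size=64, num_steps=10)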
"""
from bigdl.orca.data import SparkXShards
data, _ = maybe_dataframe_to_xshards(data,
validation_data=None,
feature_cols=feature_cols,
label_cols=label_cols,
mode="evaluate",
num_workers=self.num_workers)
if isinstance(data, SparkXShards):
if data._get_class_name() == 'pandas.core.frame.DataFrame':
data = process_xshards_of_pandas_dataframe(data, feature_cols, label_cols)
from bigdl.orca.data.utils import process_spark_xshards
ray_xshards = process_spark_xshards(data, self.num_workers)
def transform_func(worker, partition_refs):
data_creator = partition_refs_to_creator(partition_refs)
# Should not wrap DistributedSampler on DataLoader for SparkXShards input.
return worker.validate.remote(
data_creator, batch_size, num_steps, profile, info, False)
worker_stats = ray_xshards.reduce_partitions_for_actors(self.remote_workers,
transform_func)
else:
assert isinstance(data, types.FunctionType), \
"data should be either an instance of SparkXShards or a callable function, but " \
"got type: {}".format(type(data))
params = dict(data_creator=data, batch_size=batch_size, num_steps=num_steps,
profile=profile, info=info)
worker_stats = ray.get([w.validate.remote(**params) for w in self.remote_workers])
return self._process_stats(worker_stats)
def get_model(self):
"""
Returns the learned PyTorch model.
:return: The learned PyTorch model.
"""
state = self.get_state_dict()
model = self.model_creator(self.config)
model_state = state["models"][0]
model.load_state_dict(model_state)
return model.module if hasattr(model, "module") else model
@enable_multi_fs_save
def save(self, model_path):
"""
Saves the Estimator state (including model and optimizer) to the provided model_path.
:param model_path: (str) Path to save the model.
:return: The path to the saved model.
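Example (a sketch of a save/load round trip on a local path)::
    path = estimator.save("/tmp/estimator.ckpt")
    estimator.load(path)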
"""
state_dict = self.get_state_dict()
torch.save(state_dict, model_path)
return model_path
@enable_multi_fs_load
def load(self, model_path):
"""
Loads the Estimator state (including model and optimizer) from the provided model_path.
:param model_path: (str) Path to the existing model.
"""
state_dict = torch.load(model_path)
self.load_state_dict(state_dict)
def save_checkpoint(self, model_path):
"""
Manually saves the Estimator state (including model and optimizer) to the provided
model_path.
:param model_path: (str) Path to save the model. Both local and remote path are supported.
e.g. "/tmp/estimator.ckpt" or "hdfs:///tmp/estimator.ckpt"
:return: None
"""
from bigdl.dllib.utils.file_utils import is_local_path
if is_local_path(model_path):
self.save(model_path)
else:
results = [
worker.save_checkpoint.remote(model_path)
for worker in self.remote_workers
]
ray.get(results)
def load_checkpoint(self, model_path):
"""
Loads the Estimator state (including model and optimizer) from the provided model_path.
:param model_path: (str) Path to the existing model. Both local and remote path are
supported. e.g. "/tmp/estimator.ckpt" or "hdfs:///tmp/estimator.ckpt"
:return: None
"""
from bigdl.dllib.utils.file_utils import is_local_path
if is_local_path(model_path):
self.load(model_path)
else:
results = [
worker.load_checkpoint.remote(model_path)
for worker in self.remote_workers
]
ray.get(results)
def shutdown(self, force=False):
"""
Shuts down workers and releases resources.
:return:
"""
if not force:
cleanup = [
worker.shutdown.remote() for worker in self.remote_workers
]
try:
ray.get(cleanup)
# Fire-and-forget: ask each actor to terminate gracefully without waiting.
[
worker.__ray_terminate__.remote()
for worker in self.remote_workers
]
except RayActorError:
logger.warning(
"Failed to shutdown gracefully, forcing a shutdown.")
for worker in self.remote_workers:
logger.warning("Killing worker {}.".format(worker))
ray.kill(worker)
else:
for worker in self.remote_workers:
logger.debug("Killing worker {}.".format(worker))
ray.kill(worker)
self.remote_workers = []
def _process_stats(self, worker_stats):
stats = {
"num_samples": sum(
s.pop("num_samples", np.nan) for s in worker_stats)
}
for stat_key in worker_stats[0]:
if isinstance(worker_stats[0][stat_key], numbers.Number):
stats[stat_key] = np.nanmean(
[s.get(stat_key, np.nan) for s in worker_stats])
else:
stats[stat_key] = worker_stats[0][stat_key]
return stats
def _train_epochs(self, data_creator,
epochs=1, batch_size=32,
profile=False, info=None,
callbacks=None):
params = dict(data_creator=data_creator, epochs=epochs,
batch_size=batch_size, profile=profile, info=info,
callbacks=callbacks)
remote_worker_stats = []
for i, w in enumerate(self.remote_workers):
stats = w.train_epochs.remote(**params)
remote_worker_stats.append(stats)
success = check_for_failure(remote_worker_stats)
if success:
return success, ray.get(remote_worker_stats)
else:
return success, None
def _predict_spark_xshards(self, xshards, param):
ray_xshards = RayXShards.from_spark_xshards(xshards)
def transform_func(worker, shards_ref):
def data_creator(config, batch_size):
return shards_ref
return worker.predict.remote(
data_creator, **param)
pred_shards = ray_xshards.transform_shards_with_actors(self.remote_workers,
transform_func)
spark_xshards = pred_shards.to_spark_xshards()
return spark_xshards
def get_state_dict(self):
stream_ids = [
worker.get_state_stream.remote()
for worker in self.remote_workers
]
# Fetch the state from the first worker that finishes; after synchronized
# distributed training, all workers hold the same model and optimizer state.
[stream_id], stream_ids = ray.wait(stream_ids, num_returns=1, timeout=None)
byte_obj = ray.get(stream_id)
_buffer = io.BytesIO(byte_obj)
state_dict = torch.load(
_buffer,
map_location="cpu")
return state_dict
def load_state_dict(self, state_dict, blocking=True):
_buffer = io.BytesIO()
torch.save(state_dict, _buffer)
state_stream = _buffer.getvalue()
state_id = ray.put(state_stream)
remote_calls = [
worker.load_state_stream.remote(state_id)
for worker in self.remote_workers
]
if blocking:
ray.get(remote_calls)
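# State round-trip sketch (hedged; `estimator` is a hypothetical instance):
# the state dict fetched from the workers can be broadcast back to all of them:
#     state = estimator.get_state_dict()
#     estimator.load_state_dict(state)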