# Source code for bigdl.orca.learn.pytorch.pytorch_ray_estimator

#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import types
import copy
import logging

from bigdl.orca import OrcaContext
from bigdl.orca.data.ray_xshards import RayXShards
from bigdl.orca.learn.pytorch.pytorch_ray_worker import PytorchRayWorker
from bigdl.orca.learn.utils import maybe_dataframe_to_xshards, dataframe_to_xshards, \
    convert_predict_xshards_to_dataframe, update_predict_xshards, \
    process_xshards_of_pandas_dataframe, reload_dataloader_creator
from bigdl.orca.ray import OrcaRayContext
from bigdl.orca.learn.pytorch.core.base_ray_estimator import BaseRayEstimator
from bigdl.orca.learn.pytorch.utils import process_stats, check_for_failure
from bigdl.orca.learn.pytorch.callbacks.maincallback import make_only_mainCallback

import ray
from bigdl.dllib.utils.log4Error import invalidInputError

from typing import TYPE_CHECKING, Union, Optional, Callable, Dict, List, Type
if TYPE_CHECKING:
    from torch.nn import Module
    from torch.optim import Optimizer
    from bigdl.orca.learn.metrics import Metric
    from torch.nn.modules.loss import _Loss as Loss
    from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
    from torch.distributed import TCPStore
    from torch.utils.data import DataLoader
    from pyspark.sql.dataframe import DataFrame as SparkDataFrame
    from ray.data import Dataset as RayDataset
    from bigdl.orca.learn.pytorch.callbacks import Callback
    from bigdl.orca.data import SparkXShards


def partition_refs_to_creator(partition_refs):
    """Wrap a list of Ray partition object refs into a DataLoader creator.

    The returned ``data_creator(config, batch_size)`` materializes the
    partitions on the worker, wraps the resulting (features, labels) arrays
    in an in-memory Dataset, and builds a PyTorch DataLoader from it.
    DataLoader keyword arguments present in ``config`` override the defaults.
    """

    # DataLoader kwargs that callers may forward through the estimator config.
    _LOADER_ARGS = ("shuffle", "sampler", "batch_sampler", "num_workers", "collate_fn",
                    "pin_memory", "drop_last", "timeout", "worker_init_fn",
                    "multiprocessing_context")

    def data_creator(config, batch_size):
        from bigdl.orca.data.utils import partitions_get_data_label, index_data, get_size
        from torch.utils.data import Dataset, DataLoader

        class NDArrayDataset(Dataset):
            # Minimal in-memory dataset over (features, labels) arrays.
            def __init__(self, features, labels):
                self.x = features  # features
                self.y = labels  # labels

            def __len__(self):
                return get_size(self.y)

            def __getitem__(self, idx):
                return index_data(self.x, idx), index_data(self.y, idx)

        loader_kwargs = {"batch_size": batch_size, "shuffle": True}
        loader_kwargs.update({name: config[name] for name in _LOADER_ARGS if name in config})

        data, label = partitions_get_data_label(ray.get(partition_refs),
                                                allow_tuple=False,
                                                allow_list=False)
        print("Data size on worker: ", len(label))
        return DataLoader(NDArrayDataset(data, label), **loader_kwargs)

    return data_creator


class PyTorchRayEstimator(BaseRayEstimator):
    """Estimator that runs PyTorch training/evaluation/inference on Ray workers.

    Worker setup is delegated to :class:`BaseRayEstimator.setup` with
    ``PytorchRayWorker`` as the runner class.
    """

    def __init__(
            self,
            *,
            model_creator: Union[Callable[[Dict], 'Module'], None],
            optimizer_creator: Union[Callable[['Module', Dict], 'Optimizer'], None]=None,
            loss_creator: Union['Loss', Callable[[Dict], 'Loss'], None]=None,
            metrics: Union['Metric', List['Metric'], None]=None,
            scheduler_creator: Optional[Callable[[Dict], 'LRScheduler']]=None,
            config: Optional[Dict]=None,
            use_tqdm: bool=False,
            backend: str="ray",
            workers_per_node: int=1,
            sync_stats: bool=True,
            log_level: int=logging.INFO):
        """
        :param model_creator: function taking the config dict and returning a torch Module.
        :param optimizer_creator: function taking (model, config) and returning an Optimizer.
        :param loss_creator: loss instance or function taking config and returning a loss.
        :param metrics: one or more Orca Metric objects forwarded to the workers.
        :param scheduler_creator: optional function returning an LR scheduler.
        :param config: extra parameters passed to the creator functions on each worker.
               Must NOT contain "batch_size" (pass it to fit/evaluate/predict instead).
        :param use_tqdm: whether workers show a tqdm progress bar.
        :param backend: distributed backend, e.g. "ray" or "horovod".
        :param workers_per_node: number of runner actors per node.
        :param sync_stats: whether to synchronize metrics across workers.
        :param log_level: logging level used by the workers.
        """
        # batch_size is a per-call argument of fit/evaluate/predict, never config state.
        if config is not None and "batch_size" in config:
            invalidInputError(False,
                              "Please do not specify batch_size in config. Input batch_size in the"
                              " fit/evaluate/predict function of the estimator instead.")

        # todo remove ray_ctx to run on workers
        ray_ctx = OrcaRayContext.get()
        if not (isinstance(model_creator, types.FunctionType) and
                isinstance(optimizer_creator, types.FunctionType)):
            # Torch model is also callable, so an explicit FunctionType check is used
            # to reject passing a Module/Optimizer instance instead of a creator.
            invalidInputError(False,
                              "Must provide a function for both model_creator and"
                              " optimizer_creator")

        self.model_creator = model_creator
        self.optimizer_creator = optimizer_creator
        self.loss_creator = loss_creator
        self.scheduler_creator = scheduler_creator
        self.use_tqdm = use_tqdm
        self.sync_stats = sync_stats
        self.backend = backend
        self.workers_per_node = workers_per_node

        self.config = {} if config is None else config
        # Shallow copy of the user config that is shipped to each worker.
        worker_config = copy.copy(self.config)
        # Everything each PytorchRayWorker needs to build its local training state.
        self.setup_params = dict(
            model_creator=self.model_creator,
            optimizer_creator=self.optimizer_creator,
            loss_creator=self.loss_creator,
            scheduler_creator=self.scheduler_creator,
            use_tqdm=self.use_tqdm,
            config=worker_config,
            metrics=metrics,
            sync_stats=sync_stats,
            log_level=log_level
        )
        self.setup(params=self.setup_params,
                   backend=self.backend,
                   runner_cls=PytorchRayWorker,
                   workers_per_node=self.workers_per_node)
[docs] def fit(self, data: Union['SparkXShards', 'SparkDataFrame', 'RayDataset', Callable[[Dict, int], 'DataLoader']], epochs: int=1, batch_size: int=32, profile: bool=False, reduce_results: bool=True, info: Optional[Dict]=None, feature_cols: Optional[List[str]]=None, label_cols: Optional[List[str]]=None, validation_data: Union['SparkXShards', 'SparkDataFrame', Callable[[Dict, int], 'DataLoader'], None]=None, callbacks: Optional[List['Callback']]=None) -> List: """ Trains a PyTorch model given training data for several epochs. Calls `TorchRunner.train_epoch()` on N parallel workers simultaneously underneath the hood. :param data: An instance of SparkXShards, a Ray Dataset, a Spark DataFrame or a function that takes config and batch_size as argument and returns a PyTorch DataLoader for training. :param epochs: The number of epochs to train the model. Default is 1. :param batch_size: Total batch size for all workers used for training. Each worker's batch size would be this value divide the total number of workers. Default is 32. If your training data is a function, you can set batch_size to be the input batch_size of the function for the PyTorch DataLoader. :param profile: Boolean. Whether to return time stats for the training procedure. Default is False. :param reduce_results: Boolean. Whether to average all metrics across all workers into one dict. If a metric is a non-numerical value, the one value will be randomly selected among the workers. If False, returns a list of dicts for all workers. Default is True. :param info: An optional dictionary that can be passed to the TorchRunner for train_epoch and train_batch. :param feature_cols: feature column names if data is Spark DataFrame or Ray Dataset. :param label_cols: label column names if data is Spark DataFrame or Ray Dataset. :param validation_data: validation data. Validation data type should be the same as train data. :param callbacks: A list for all callbacks. 
Note that only one MainCallback is allowed among all callbacks. :return: A list of dictionary of metrics for every training epoch. If reduce_results is False, this will return a nested list of metric dictionaries whose length will be equal to the total number of workers. You can also provide custom metrics by passing in a custom HookClass(after 2.2.0) when creating the Estimator. """ invalidInputError(isinstance(batch_size, int) and batch_size > 0, "batch_size should be a positive integer") batch_size = batch_size // self.num_workers # Local batch size for each worker if batch_size <= 0: batch_size = 1 # Check uniqueness of the MainCallback if not callbacks: callbacks = [] make_only_mainCallback(callbacks) params = dict( epochs=epochs, batch_size=batch_size, profile=profile, info=info, callbacks=callbacks, ) if self.backend == "ray" and not self.init_ddp_process: self.setup_torch_ddp() from bigdl.orca.data import SparkXShards from ray.data import Dataset data, validation_data = maybe_dataframe_to_xshards(data, validation_data=validation_data, feature_cols=feature_cols, label_cols=label_cols, mode="fit", num_workers=self.num_workers, shard_size=batch_size) if isinstance(data, SparkXShards): # Should not wrap DistributedSampler on DataLoader for SparkXShards input. 
params["wrap_dataloader"] = False if data._get_class_name() == 'pandas.core.frame.DataFrame': data, validation_data = process_xshards_of_pandas_dataframe(data, feature_cols, label_cols, validation_data, "fit") from bigdl.orca.data.utils import process_spark_xshards ray_xshards = process_spark_xshards(data, self.num_workers) # type:ignore if validation_data is None: def transform_func(worker, partition_refs): data_creator = partition_refs_to_creator(partition_refs) params["data_creator"] = data_creator return worker.train_epochs.remote(**params) worker_stats = ray_xshards.reduce_partitions_for_actors(self.remote_workers, transform_func) else: if self.backend == "horovod": invalidInputError(False, "Currently, we don't support input validation_data" " for horovod backend") val_ray_xshards = process_spark_xshards(validation_data, # type:ignore self.num_workers) def zip_func(worker, this_partition_refs, that_partition_refs): params["data_creator"] = partition_refs_to_creator(this_partition_refs) params["validation_data_creator"] = \ partition_refs_to_creator(that_partition_refs) return worker.train_epochs.remote(**params) worker_stats = ray_xshards.zip_reduce_shards_with_actors(val_ray_xshards, self.remote_workers, zip_func) elif isinstance(data, Dataset): # todo: need to refactor to align with tf2 code params["wrap_dataloader"] = False shards = data.split(n=self.num_workers, locality_hints=self.remote_workers) def make_data_creator(shard, feature_cols, label_cols): def data_creator(config, batch_size): torch_datashard = shard.to_torch(label_column=label_cols, feature_columns=feature_cols, batch_size=batch_size) return torch_datashard return data_creator remote_worker_stats = [] if validation_data is None: for shard, worker in zip(shards, self.remote_workers): params["data_creator"] = make_data_creator(shard, feature_cols, label_cols) stats = worker.train_epochs.remote(**params) remote_worker_stats.append(stats) else: if self.backend == "horovod": 
invalidInputError(False, "Currently, we don't support input validation_data for" " horovod backend") if not isinstance(validation_data, ray.data.Dataset): invalidInputError(False, "Validation data type should be the same as train data," " but got type: {}".format(type(validation_data))) val_shards = validation_data.split(n=self.num_workers, # type:ignore locality_hints=self.remote_workers) for shard, val_shard, worker in zip(shards, val_shards, # type:ignore self.num_workers): params["data_creator"] = make_data_creator(shard, feature_cols, label_cols) params["validation_data_creator"] = make_data_creator(val_shard, feature_cols, label_cols) stats = worker.train_epochs.remote(**params) remote_worker_stats.append(stats) success = check_for_failure(remote_worker_stats) if success: worker_stats = ray.get(remote_worker_stats) else: worker_stats = None else: invalidInputError(isinstance(data, types.FunctionType), "data should be either an instance of SparkXShards," " Ray Dataset or a callable function, but" " got type: {}".format(type(data))) params["data_creator"] = reload_dataloader_creator(data) params["validation_data_creator"] = reload_dataloader_creator(validation_data) success, worker_stats = self._train_epochs(**params) epoch_stats = list(map(list, zip(*worker_stats))) # type:ignore if reduce_results: for i in range(len(epoch_stats)): epoch_stats[i] = process_stats(epoch_stats[i]) return epoch_stats else: return epoch_stats
[docs] def predict(self, data: Union['SparkXShards', 'SparkDataFrame'], batch_size: int=32, feature_cols: Optional[List[str]]=None, profile: bool=False) -> Union['SparkXShards', 'SparkDataFrame']: """ Using this PyTorch model to make predictions on the data. :param data: An instance of SparkXShards, a Ray Dataset or a Spark DataFrame :param batch_size: Total batch size for all workers used for inference. Each worker's batch size would be this value divide the total number of workers. Default is 32. :param profile: Boolean. Whether to return time stats for the training procedure. Default is False. :param feature_cols: feature column names if data is a Spark DataFrame or Ray Dataset. :return: A SparkXShards or a list that contains the predictions with key "prediction" in each shard """ invalidInputError(isinstance(batch_size, int) and batch_size > 0, "batch_size should be a positive integer") batch_size = batch_size // self.num_workers # Local batch size for each worker if batch_size <= 0: batch_size = 1 from bigdl.orca.data import SparkXShards param = dict( batch_size=batch_size, profile=profile ) from pyspark.sql import DataFrame if isinstance(data, DataFrame): xshards, _ = dataframe_to_xshards(data, validation_data=None, feature_cols=feature_cols, label_cols=None, mode="predict", shard_size=batch_size) pred_shards = self._predict_spark_xshards(xshards, param) result = convert_predict_xshards_to_dataframe(data, pred_shards) elif isinstance(data, SparkXShards): if data._get_class_name() == 'pandas.core.frame.DataFrame': data = process_xshards_of_pandas_dataframe(data, feature_cols) pred_shards = self._predict_spark_xshards(data, param) result = update_predict_xshards(data, pred_shards) elif isinstance(data, ray.data.Dataset): shards = data.split(n=self.num_workers, locality_hints=self.remote_workers) def data_creator(config, batch_size): torch_datashard = shard.to_torch(feature_columns=feature_cols, batch_size=batch_size) return torch_datashard remote_worker_stats = 
[] for shard, worker in zip(shards, self.remote_workers): worker_stats = worker.predict.remote(data_creator, batch_size, profile) remote_worker_stats.append(worker_stats) result = ray.data.from_numpy(remote_worker_stats).map( lambda r: {"prediction_result": r["value"]}) else: invalidInputError(False, "Only xshards, Spark DataFrame or Ray Dataset" " is supported for predict") return result
[docs] def evaluate(self, data: Union['SparkXShards', 'SparkDataFrame', 'RayDataset', Callable[[Dict, int], 'DataLoader']], batch_size: int=32, num_steps: int=None, profile: bool=False, reduce_results: bool=True, info: Dict=None, feature_cols: Optional[List[str]]=None, label_cols: Optional[List[str]]=None, callbacks: Optional[List['Callback']]=None) -> Union[List[Dict], Dict]: """ Evaluates a PyTorch model given validation data. Note that only accuracy for classification with zero-based label is supported by default. You can override validate_batch in TorchRunner for other metrics. Calls `TorchRunner.validate()` on N parallel workers simultaneously underneath the hood. :param data: An instance of SparkXShards, a Spark DataFrame, a Ray Dataset or a function that takes config and batch_size as argument and returns a PyTorch DataLoader for validation. :param batch_size: Total batch size for all workers used for evaluation. Each worker's batch size would be this value divide the total number of workers. Default: 32. If your validation data is a function, you can set batch_size to be the input batch_size of the function for the PyTorch DataLoader. :param num_steps: The number of batches to compute the validation results on. This corresponds to the number of times `TorchRunner.validate_batch` is called. :param profile: Boolean. Whether to return time stats for the training procedure. Default is False. :param reduce_results: Boolean. Whether to average all metrics across all workers into one dict. If a metric is a non-numerical value, the one value will be randomly selected among the workers. If False, returns a list of dicts for all workers. Default is True. :param info: An optional dictionary that can be passed to the TorchRunner for validate. :param feature_cols: feature column names if train data is Spark DataFrame or Ray Dataset. :param label_cols: label column names if train data is Spark DataFrame or Ray Dataset. :param callbacks: A list for all callbacks. 
Note that only one MainCallback is allowed among all callbacks. :return: A dictionary of metrics for the given data, including validation accuracy and loss. You can also provide custom metrics by passing in a custom HookClass(after 2.2.0) when creating the Estimator. """ invalidInputError(isinstance(batch_size, int) and batch_size > 0, "batch_size should be a positive integer") batch_size = batch_size // self.num_workers # Local batch size for each worker if batch_size <= 0: batch_size = 1 from bigdl.orca.data import SparkXShards data, _ = maybe_dataframe_to_xshards(data, validation_data=None, feature_cols=feature_cols, label_cols=label_cols, mode="evaluate", num_workers=self.num_workers, shard_size=batch_size) # Check uniqueness of the MainCallback if not callbacks: callbacks = [] make_only_mainCallback(callbacks) params = dict(batch_size=batch_size, num_steps=num_steps, profile=profile, info=info, wrap_dataloader=False, callbacks=callbacks) if isinstance(data, SparkXShards): if data._get_class_name() == 'pandas.core.frame.DataFrame': data = process_xshards_of_pandas_dataframe(data, feature_cols, label_cols) from bigdl.orca.data.utils import process_spark_xshards ray_xshards = process_spark_xshards(data, self.num_workers) # type:ignore def transform_func(worker, partition_refs): data_creator = partition_refs_to_creator(partition_refs) # Should not wrap DistributedSampler on DataLoader for SparkXShards input. 
params["data_creator"] = data_creator return worker.validate.remote(**params) worker_stats = ray_xshards.reduce_partitions_for_actors(self.remote_workers, transform_func) elif isinstance(data, ray.data.Dataset): shards = data.split(n=self.num_workers, locality_hints=self.remote_workers) def data_creator(config, batch_size): torch_datashard = shard.to_torch(label_column=label_cols, feature_columns=feature_cols, batch_size=batch_size) return torch_datashard remote_worker_stats = [] params["data_creator"] = data_creator # type:ignore for shard, worker in zip(shards, self.remote_workers): stats = worker.validate.remote(**params) remote_worker_stats.append(stats) worker_stats = ray.get(remote_worker_stats) else: invalidInputError(isinstance(data, types.FunctionType), "data should be either an instance of SparkXShards or a callable" " function, but got type: {}".format(type(data))) params["data_creator"] = reload_dataloader_creator(data) worker_stats = ray.get([w.validate.remote(**params) for w in self.remote_workers]) if reduce_results: return process_stats(worker_stats) else: return worker_stats
[docs] def get_model(self) -> 'Module': """ Returns the learned PyTorch model. :return: The learned PyTorch model. """ state = self.get_state_dict() model = self.model_creator(self.config) # type:ignore model_state = state["models"][0] model.load_state_dict(model_state) return model.module if hasattr(model, "module") else model # type:ignore
def _predict_spark_xshards(self, xshards, param): ray_xshards = RayXShards.from_spark_xshards(xshards) def transform_func(worker, shards_ref): data_creator = lambda config, batch_size: shards_ref return worker.predict.remote( data_creator, **param) pred_shards = ray_xshards.transform_shards_with_actors(self.remote_workers, transform_func) spark_xshards = pred_shards.to_spark_xshards() return spark_xshards