Orca Learn#





Orca TF2Estimator with the “horovod” or “ray” backend.


Orca TF2Estimator with the “spark” backend.


class bigdl.orca.learn.pytorch.estimator.Estimator[source]#

Bases: object

static from_torch(*, model: Optional[Union[Module, Callable[[Dict], Module]]] = None, optimizer: Optional[Union[Optimizer, Callable[[Module, Dict], Optimizer]]] = None, loss: Optional[Union[Loss, Callable[[Dict], Loss]]] = None, metrics: Optional[Union[Metric, List[Metric]]] = None, backend: str = 'spark', config: Optional[Dict] = None, workers_per_node: int = 1, scheduler_creator: Optional[Callable[[Dict], LRScheduler]] = None, scheduler_step_freq: str = 'epoch', use_tqdm: bool = False, model_dir: Optional[str] = None, sync_stats: bool = False, log_level: int = 20, log_to_driver: bool = True) → Optional[Union[PyTorchRayEstimator, PyTorchSparkEstimator, PyTorchPySparkEstimator]][source]#

Create an Estimator for PyTorch.

  • model – A PyTorch model, or a model creator function that takes the parameter “config” and returns a PyTorch model.

  • optimizer – An optimizer creator function that has two parameters “model” and “config” and returns a PyTorch optimizer. Default: None if training is not performed.

  • loss – An instance of PyTorch loss, or a loss creator function that takes the parameter “config” and returns a PyTorch loss. Default: None if loss computation is not needed.

  • metrics – One or a list of Orca validation metrics. Function(s) that computes the metrics between the output and target tensors are also supported. Default: None if no validation is involved.

  • backend – The distributed backend for the Estimator. One of “spark”, “ray”, “bigdl” or “horovod”. Default: “spark”.

  • config – A parameter config dict, CfgNode or any class instance that serves as the configuration for creating the model, loss, optimizer, scheduler and data. Default: None if no config is needed.

  • workers_per_node – The number of PyTorch workers on each node. Default: 1.

  • scheduler_creator – A scheduler creator function that has two parameters “optimizer” and “config” and returns a PyTorch learning rate scheduler wrapping the optimizer. Note that if you specify this parameter, you also need to set the argument scheduler_step_freq accordingly. Default: None if no scheduler is needed.

  • scheduler_step_freq – The frequency at which scheduler.step is called: “batch” or “epoch”, if a scheduler is used. Default: “epoch”.

  • use_tqdm – Whether to use tqdm to monitor the training progress. Default: False.

  • model_dir – The path to save the PyTorch model during the training if checkpoint_trigger is defined and triggered. Default: None.

  • sync_stats – Whether to sync metrics across all distributed workers after each epoch. If set to False, only the metrics of the worker with rank 0 are printed. Default: False.

  • log_level – The log_level of each distributed worker. Default: logging.INFO.

  • log_to_driver – Whether to display executor log on driver in cluster mode for spark backend. Default: True.


An Estimator object for PyTorch.
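The creator-function pattern expected by from_torch can be sketched as follows. TinyNet and SgdStub are hypothetical stand-ins (plain Python, no torch dependency) that only show how the “config” dict flows into each creator; in real code model_creator would return a torch.nn.Module, optimizer_creator a torch.optim.Optimizer, and the creators would be passed to Estimator.from_torch:

```python
# Minimal sketch of the creator functions from_torch expects.
# TinyNet and SgdStub are hypothetical stand-ins for torch.nn.Module
# and torch.optim.Optimizer so the pattern is runnable without PyTorch.

class TinyNet:
    def __init__(self, hidden_size):
        self.hidden_size = hidden_size

class SgdStub:
    def __init__(self, model, lr):
        self.model, self.lr = model, lr

def model_creator(config):
    # Receives the same dict passed to from_torch via the "config" argument.
    return TinyNet(hidden_size=config.get("hidden_size", 32))

def optimizer_creator(model, config):
    # Receives the already-created model plus the same config dict.
    return SgdStub(model, lr=config.get("lr", 0.01))

config = {"hidden_size": 64, "lr": 0.1}
model = model_creator(config)
opt = optimizer_creator(model, config)
print(model.hidden_size, opt.lr)  # → 64 0.1

# With PyTorch and BigDL installed, the creators would be used as, e.g.:
# est = Estimator.from_torch(model=model_creator,
#                            optimizer=optimizer_creator,
#                            loss=nn.MSELoss(),
#                            backend="spark",
#                            config=config)
```

Passing creator functions rather than instances lets each distributed worker build its own model and optimizer locally from the shared config.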

static latest_checkpoint(checkpoint_dir: str) → str[source]#

Get the path to the latest checkpoint file under checkpoint_dir.
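As an illustration of the kind of lookup such a helper performs, here is a minimal sketch assuming checkpoints are files whose modification time orders them; latest_checkpoint_sketch and the epoch_N.ckpt naming are made up for this example, and BigDL's actual on-disk checkpoint layout may differ:

```python
import os
import tempfile

def latest_checkpoint_sketch(checkpoint_dir):
    # Hypothetical stand-in: return the most recently modified file in
    # the directory. BigDL's real implementation may use a different
    # naming or layout scheme.
    paths = [os.path.join(checkpoint_dir, f) for f in os.listdir(checkpoint_dir)]
    return max(paths, key=os.path.getmtime)

with tempfile.TemporaryDirectory() as d:
    for i, name in enumerate(["epoch_1.ckpt", "epoch_2.ckpt"]):
        path = os.path.join(d, name)
        open(path, "w").close()
        os.utime(path, (1000 + i, 1000 + i))  # explicit mtimes for determinism
    latest = latest_checkpoint_sketch(d)

print(latest)  # path ending in epoch_2.ckpt
```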


Orca PyTorch Estimator with the “horovod” or “ray” backend.


Orca PyTorch Estimator with the “bigdl” backend.