Source code for bigdl.orca.learn.pytorch.estimator

#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import logging
from bigdl.orca.learn.pytorch.training_operator import TrainingOperator


class Estimator(object):
    @staticmethod
    def from_torch(*,
                   model,
                   optimizer,
                   loss=None,
                   metrics=None,
                   scheduler_creator=None,
                   training_operator_cls=TrainingOperator,
                   initialization_hook=None,
                   config=None,
                   scheduler_step_freq="batch",
                   use_tqdm=False,
                   workers_per_node=1,
                   model_dir=None,
                   backend="bigdl",
                   sync_stats=False,
                   log_level=logging.INFO,
                   log_to_driver=True,
                   ):
        """
        Create an Estimator for torch.

        :param model: PyTorch model or model creator function if backend="bigdl",
               PyTorch model creator function if backend="horovod" or "torch_distributed".
        :param optimizer: Orca/PyTorch optimizer or optimizer creator function if
               backend="bigdl", PyTorch optimizer creator function if backend="horovod"
               or "torch_distributed".
        :param loss: PyTorch loss or loss creator function if backend="bigdl",
               PyTorch loss creator function if backend="horovod" or "torch_distributed".
        :param metrics: Orca validation methods for evaluate.
        :param scheduler_creator: parameter for `horovod` and `torch_distributed` backends.
               A learning rate scheduler wrapping the optimizer. You will need to set
               ``scheduler_step_freq="epoch"`` for the scheduler to be incremented correctly.
        :param config: parameter config dict used to create the model, optimizer, loss and data.
        :param scheduler_step_freq: parameter for `horovod` and `torch_distributed` backends.
               "batch", "epoch" or None. This determines when ``scheduler.step`` is called.
               If "batch", ``step`` will be called after every optimizer step. If "epoch",
               ``step`` will be called after one pass of the DataLoader. If a scheduler is
               passed in, this value is expected to not be None.
        :param use_tqdm: parameter for `horovod` and `torch_distributed` backends.
               You can monitor training progress if use_tqdm=True.
        :param workers_per_node: parameter for `horovod` and `torch_distributed` backends.
               Number of workers on each node. Default: 1.
        :param model_dir: parameter for `bigdl` and `spark` backends. The path to save the
               model. During training, if checkpoint_trigger is defined and triggered, the
               model will be saved to model_dir.
        :param backend: You can choose "horovod", "torch_distributed", "bigdl" or "spark"
               as backend. Default: `bigdl`.
        :param sync_stats: Whether to sync metrics across all distributed workers after
               each epoch. If set to False, only rank 0's metrics are printed. This param
               only works for the horovod, torch_distributed and spark backends; for the
               spark backend, the printed metrics are always synced. This param only
               affects the printed metrics; the returned metrics are always averaged
               across workers. Default: False.
        :param log_level: Sets the log level of each distributed worker. This param only
               works for the horovod, torch_distributed and spark backends.
        :param log_to_driver: (bool) Whether to display executor logs on the driver in
               cluster mode. Default: True. This option is only for the "spark" backend.

        :return: an Estimator object. A usage sketch follows the class definition below.
        """
        if backend in {"horovod", "torch_distributed"}:
            from bigdl.orca.learn.pytorch.pytorch_ray_estimator import PyTorchRayEstimator
            return PyTorchRayEstimator(model_creator=model,
                                       optimizer_creator=optimizer,
                                       loss_creator=loss,
                                       metrics=metrics,
                                       scheduler_creator=scheduler_creator,
                                       training_operator_cls=training_operator_cls,
                                       initialization_hook=initialization_hook,
                                       config=config,
                                       scheduler_step_freq=scheduler_step_freq,
                                       use_tqdm=use_tqdm,
                                       workers_per_node=workers_per_node,
                                       backend=backend,
                                       sync_stats=sync_stats,
                                       log_level=log_level)
        elif backend == "bigdl":
            from bigdl.orca.learn.pytorch.pytorch_spark_estimator import PyTorchSparkEstimator
            return PyTorchSparkEstimator(model=model,
                                         loss=loss,
                                         optimizer=optimizer,
                                         config=config,
                                         metrics=metrics,
                                         model_dir=model_dir,
                                         bigdl_type="float")
        elif backend == "spark":
            from bigdl.orca.learn.pytorch.pytorch_pyspark_estimator import PyTorchPySparkEstimator
            return PyTorchPySparkEstimator(model_creator=model,
                                           optimizer_creator=optimizer,
                                           loss_creator=loss,
                                           metrics=metrics,
                                           scheduler_creator=scheduler_creator,
                                           training_operator_cls=training_operator_cls,
                                           initialization_hook=initialization_hook,
                                           config=config,
                                           scheduler_step_freq=scheduler_step_freq,
                                           use_tqdm=use_tqdm,
                                           workers_per_node=workers_per_node,
                                           sync_stats=sync_stats,
                                           log_level=log_level,
                                           model_dir=model_dir,
                                           log_to_driver=log_to_driver,
                                           )
        else:
            raise ValueError("Only horovod, torch_distributed, bigdl and spark backends are "
                             f"supported for now, got backend: {backend}")
    @staticmethod
    def latest_checkpoint(checkpoint_dir):
        from .callbacks.model_checkpoint import ModelCheckpoint
        checkpoint_path = ModelCheckpoint.get_latest_checkpoint(checkpoint_dir)
        return checkpoint_path
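
For context, below is a minimal, hedged usage sketch of ``from_torch`` with the default ``bigdl`` backend; it is not part of this module. The toy model, the synthetic data creator, and the ``init_orca_context``/``fit`` calls are illustrative assumptions about the surrounding Orca API rather than code taken from this file.

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

from bigdl.orca import init_orca_context, stop_orca_context  # assumed Orca entry points
from bigdl.orca.learn.pytorch import Estimator               # assumed public import path

init_orca_context(cluster_mode="local")                      # start a local Orca context

model = nn.Linear(10, 2)                                     # toy torch.nn.Module
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)     # plain PyTorch optimizer
loss = nn.CrossEntropyLoss()                                 # plain PyTorch loss


def train_loader_creator(config, batch_size):
    # Synthetic data just to keep the sketch self-contained.
    features = torch.randn(64, 10)
    labels = torch.randint(0, 2, (64,))
    return DataLoader(TensorDataset(features, labels), batch_size=batch_size)


# With backend="bigdl", an in-memory model/optimizer/loss can be passed directly,
# as described in the docstring above.
est = Estimator.from_torch(model=model,
                           optimizer=optimizer,
                           loss=loss,
                           backend="bigdl")
est.fit(data=train_loader_creator, epochs=1, batch_size=16)  # fit signature assumed
stop_orca_context()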
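
A similarly hedged sketch of ``latest_checkpoint``, which simply resolves the most recent checkpoint file written under a checkpoint directory; the directory path below is a placeholder.

from bigdl.orca.learn.pytorch import Estimator  # assumed public import path

# Resolve the newest checkpoint produced during a previous fit() run; the
# directory is a placeholder for wherever ModelCheckpoint wrote its files.
ckpt_path = Estimator.latest_checkpoint("/tmp/orca_checkpoints")
print(ckpt_path)
# The returned path would typically be handed to the estimator's checkpoint
# loading routine (not shown in this file, so no API name is assumed here).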