# Source code for bigdl.orca.learn.tf2.estimator

#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import logging
import numpy as np

from bigdl.orca.learn.utils import get_latest_checkpoint
from bigdl.dllib.utils.log4Error import invalidInputError


logger = logging.getLogger(__name__)


class Estimator(object):
    """Factory for Orca TensorFlow 2 estimators.

    The class keeps no state of its own; it only exposes static helpers that
    build a backend-specific estimator or look up a checkpoint.
    """

    @staticmethod
    def from_keras(*,
                   model_creator,
                   config=None,
                   verbose=False,
                   workers_per_node=1,
                   compile_args_creator=None,
                   backend="ray",
                   cpu_binding=False,
                   log_to_driver=True,
                   model_dir=None,
                   **kwargs
                   ):
        """
        Create an Estimator for tensorflow 2.

        All arguments are keyword-only.

        :param model_creator: (dict -> Model) This function takes in the `config`
               dict and returns a compiled TF model.
        :param config: (dict) configuration passed to 'model_creator',
               'data_creator'. Also contains `fit_config`, which is passed
               into `model.fit(data, **fit_config)` and
               `evaluate_config` which is passed into `model.evaluate`.
        :param verbose: (bool) Prints output of one model if true.
        :param workers_per_node: (Int) worker number on each node. default: 1.
        :param compile_args_creator: (dict -> dict of loss, optimizer and metrics)
               This function takes in the `config` dict and returns a dictionary like
               {"optimizer": tf.keras.optimizers.SGD(lr), "loss":
               "mean_squared_error", "metrics": ["mean_squared_error"]}.
               The original documentation states it is only used when
               backend="horovod"; note that this factory forwards it to
               whichever backend estimator is created.
        :param backend: (string) You can choose "horovod", "ray" or "spark" as
               backend. Default: `ray`.
        :param cpu_binding: (bool) Whether to bind threads to specific CPUs.
               Default: False. Must not be True with the "spark" backend.
        :param log_to_driver: (bool) Whether to display executor logs on the
               driver in cluster mode. Default: True. This option is only
               forwarded for the "spark" backend.
        :param model_dir: (str) The directory to save model states. It is
               required for the "spark" backend (this factory itself does not
               check it). For cluster mode, it should be a shared filesystem
               path which can be accessed by executors. It is only forwarded
               for the "spark" backend.
        :param kwargs: extra keyword arguments, forwarded to the estimator only
               for the "spark" backend; they are ignored for "ray"/"horovod".
        :return: a `TensorFlow2Estimator` for "ray"/"horovod", or a
               `SparkTFEstimator` for "spark".

        For an unsupported backend, or cpu_binding=True with "spark",
        `invalidInputError` is called with a failing condition (presumably
        raising -- see bigdl.dllib.utils.log4Error).
        """
        if backend in {"ray", "horovod"}:
            # Imported lazily so the ray stack is only needed when selected.
            from bigdl.orca.learn.tf2.ray_estimator import TensorFlow2Estimator
            return TensorFlow2Estimator(model_creator=model_creator,
                                        config=config,
                                        verbose=verbose,
                                        workers_per_node=workers_per_node,
                                        backend=backend,
                                        compile_args_creator=compile_args_creator,
                                        cpu_binding=cpu_binding)
        elif backend == "spark":
            # Reject the unsupported option before importing the spark estimator.
            if cpu_binding:
                invalidInputError(False,
                                  "cpu_binding should not be True when using spark backend")
            from bigdl.orca.learn.tf2.pyspark_estimator import SparkTFEstimator
            return SparkTFEstimator(model_creator=model_creator,
                                    config=config,
                                    verbose=verbose,
                                    compile_args_creator=compile_args_creator,
                                    workers_per_node=workers_per_node,
                                    log_to_driver=log_to_driver,
                                    model_dir=model_dir,
                                    **kwargs)
        else:
            invalidInputError(False,
                              "Only horovod, ray and spark backends are supported"
                              f" for now, got backend: {backend}")

    @staticmethod
    def latest_checkpoint(checkpoint_dir):
        """
        Return whatever `get_latest_checkpoint` reports for `checkpoint_dir`.

        :param checkpoint_dir: passed through unchanged to
               `bigdl.orca.learn.utils.get_latest_checkpoint`.
        :return: the helper's result, unmodified.
        """
        return get_latest_checkpoint(checkpoint_dir)
def make_data_creator(refs):
    """
    Wrap `refs` in a callable with the data-creator signature.

    :param refs: any object; it is captured by the returned closure.
    :return: a function `data_creator(config, batch_size)` that ignores both
             arguments and always returns the same `refs` object.
    """
    def data_creator(config, batch_size):
        # Both arguments exist only to satisfy the expected call signature.
        return refs

    return data_creator
def data_length(data):
    """
    Return the size of the first axis of the features stored under `data["x"]`.

    :param data: a mapping with an "x" entry that is either a numpy ndarray or
           an indexable collection whose first element has a `shape`.
    :return: `x.shape[0]` when "x" is an ndarray, otherwise `x[0].shape[0]`.
    """
    features = data["x"]
    if isinstance(features, np.ndarray):
        return features.shape[0]
    # Not a single array: only the first element is inspected.
    return features[0].shape[0]