Source code for bigdl.nano.automl.tf.keras.Model

#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#


import tensorflow as tf

from bigdl.nano.automl.utils import proxy_methods
from bigdl.nano.automl.tf.mixin import HPOMixin
from bigdl.nano.automl.hpo.callgraph import CallCache


@proxy_methods
class Model(HPOMixin, tf.keras.Model):
    """Tf.keras.Model with HPO capabilities.

    The ``inputs`` and ``outputs`` keyword arguments given at construction
    are kept as "lazy" values; they are resolved through
    ``CallCache.execute`` for a given trial when the model init arguments
    are requested.
    """

    def __init__(self, **kwargs):
        """Initializer.

        Only keyword arguments are accepted for now.

        :param kwargs: keyword arguments recorded for later model
            construction. The ``inputs`` and ``outputs`` entries, when
            present, are additionally stored as the lazy inputs/outputs
            (``None`` when absent).
        """
        # TODO check how args is used
        # TODO check why base class is keras.engine.training_v1.Model
        super().__init__()
        self.model_class = tf.keras.Model
        self.kwargs = kwargs
        # Read from the local ``kwargs`` rather than ``self.kwargs``
        # NOTE(review): Keras may wrap dict attributes on assignment - confirm.
        self.lazyinputs_ = kwargs.get('inputs')
        self.lazyoutputs_ = kwargs.get('outputs')

    def _model_init_args(self, trial):
        """Return the model init kwargs resolved for ``trial``.

        Used for lazy model init: the backend samples the model init
        arguments and the actual layers are constructed by
        ``CallCache.execute``. The resulting tensors are written back
        into ``self.kwargs`` under ``'inputs'`` and ``'outputs'``, and
        ``self.kwargs`` is returned.
        """
        resolved_inputs, resolved_outputs = CallCache.execute(
            self.lazyinputs_,
            self.lazyoutputs_,
            trial,
            self.backend,
        )
        self.kwargs['inputs'] = resolved_inputs
        self.kwargs['outputs'] = resolved_outputs
        return self.kwargs

    def _get_model_init_args_func_kwargs(self):
        """Return the kwargs of _model_init_args_func except trial."""
        return dict(
            lazyinputs=self.lazyinputs_,
            lazyoutputs=self.lazyoutputs_,
            kwargs=self.kwargs,
            backend=self.backend,
        )

    @staticmethod
    def _model_init_args_func(trial, lazyinputs, lazyoutputs, kwargs, backend):
        """Instance-free counterpart of ``_model_init_args``.

        Resolves ``lazyinputs`` / ``lazyoutputs`` for ``trial`` through
        ``CallCache.execute`` using ``backend``, stores the results into
        the given ``kwargs`` dict (modified in place) under ``'inputs'``
        and ``'outputs'``, and returns that same dict.
        """
        resolved_inputs, resolved_outputs = CallCache.execute(
            lazyinputs,
            lazyoutputs,
            trial,
            backend,
        )
        kwargs['inputs'] = resolved_inputs
        kwargs['outputs'] = resolved_outputs
        return kwargs