# EVOLUTION-MANAGER
# Edit File: utils.py
# Copyright 2018 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== """Utility functions shared between SavedModel saving/loading implementations.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function import itertools import threading import types from tensorflow.python.eager import context from tensorflow.python.keras import backend as K from tensorflow.python.keras.engine import base_layer_utils from tensorflow.python.keras.utils import control_flow_util from tensorflow.python.keras.utils import layer_utils from tensorflow.python.keras.utils import tf_contextlib from tensorflow.python.keras.utils import tf_inspect from tensorflow.python.keras.utils.generic_utils import LazyLoader from tensorflow.python.util import tf_decorator # pylint:disable=g-inconsistent-quotes training_lib = LazyLoader( "training_lib", globals(), "tensorflow.python.keras.engine.training") # pylint:enable=g-inconsistent-quotes def use_wrapped_call(layer, call_fn, default_training_value=None, return_method=False): """Creates fn that adds the losses returned by call_fn & returns the outputs. Args: layer: A Keras layer object call_fn: tf.function that takes layer inputs (and possibly a training arg), and returns a tuple of (outputs, list of losses). default_training_value: Default value of the training kwarg. 
If `None`, the default is `K.learning_phase()`. return_method: Whether to return a method bound to the layer. Returns: function that calls call_fn and returns the outputs. Losses returned by call_fn are added to the layer losses. """ expects_training_arg = layer_uses_training_bool(layer) if hasattr(call_fn, 'original_call'): # call_fn is a LayerCall object original_call = call_fn.original_call # In Python 3, callable objects are not compatible with inspect.getargspec call_fn = call_fn.__call__ else: original_call = call_fn fn, arg_spec = maybe_add_training_arg( original_call, call_fn, expects_training_arg, default_training_value) def return_outputs_and_add_losses(*args, **kwargs): """Returns the outputs from the call_fn, and adds the losses.""" inputs_arg_index = 1 if return_method else 0 inputs = args[inputs_arg_index] args = args[inputs_arg_index + 1:] outputs, losses = fn(inputs, *args, **kwargs) layer.add_loss(losses, inputs=inputs) # TODO(kathywu): This is a temporary hack. When a network of layers is # revived from SavedModel, only the top-level layer will have losses. This # causes issues in eager mode because the child layers may have graph losses # (thus model.losses returns a mix of Eager and graph tensors). To fix this, # whenever eager losses are added to one layer, add eager losses to all # child layers. This causes `.losses` to only return eager losses. 
# pylint: disable=protected-access if context.executing_eagerly(): for i in layer._flatten_layers(): if i is not layer: i._eager_losses = [base_layer_utils.REVIVED_LOSS_PLACEHOLDER] # pylint: enable=protected-access return outputs decorated = tf_decorator.make_decorator( target=call_fn, decorator_func=return_outputs_and_add_losses, decorator_argspec=arg_spec) if return_method: return types.MethodType(decorated, layer) else: return decorated def layer_uses_training_bool(layer): """Returns whether this layer or any of its children uses the training arg.""" if layer._expects_training_arg: # pylint: disable=protected-access return True visited = {layer} to_visit = list_all_layers(layer) while to_visit: layer = to_visit.pop() if layer in visited: continue if getattr(layer, '_expects_training_arg', True): return True visited.add(layer) to_visit.extend(list_all_layers(layer)) return False def list_all_layers(obj): if isinstance(obj, training_lib.Model): return obj.layers else: return list( layer_utils.filter_empty_layer_containers(obj._layers)) # pylint: disable=protected-access def list_all_layers_and_sublayers(obj): s = set([obj]) s.update(itertools.chain.from_iterable( list_all_layers_and_sublayers(layer) for layer in list_all_layers(obj))) return s def maybe_add_training_arg( original_call, wrapped_call, expects_training_arg, default_training_value): """Decorate call and optionally adds training argument. If a layer expects a training argument, this function ensures that 'training' is present in the layer args or kwonly args, with the default training value. Args: original_call: Original call function. wrapped_call: Wrapped call function. expects_training_arg: Whether to include 'training' argument. default_training_value: Default value of the training kwarg to include in the arg spec. If `None`, the default is `K.learning_phase()`. 
Returns: Tuple of ( function that calls `wrapped_call` and sets the training arg, Argspec of returned function or `None` if the argspec is unchanged) """ if not expects_training_arg: return wrapped_call, None def wrap_with_training_arg(*args, **kwargs): """Wrap the `wrapped_call` function, and set training argument.""" training_arg_index = get_training_arg_index(original_call) training = get_training_arg(training_arg_index, args, kwargs) if training is None: training = default_training_value or K.learning_phase() args = list(args) kwargs = kwargs.copy() def replace_training_and_call(training): set_training_arg(training, training_arg_index, args, kwargs) return wrapped_call(*args, **kwargs) return control_flow_util.smart_cond( training, lambda: replace_training_and_call(True), lambda: replace_training_and_call(False)) # Create arg spec for decorated function. If 'training' is not defined in the # args of the original arg spec, then add it to kwonlyargs. arg_spec = tf_inspect.getfullargspec(original_call) defaults = list(arg_spec.defaults) if arg_spec.defaults is not None else [] kwonlyargs = arg_spec.kwonlyargs kwonlydefaults = arg_spec.kwonlydefaults or {} # Add training arg if it does not exist, or set the default training value. 
if 'training' not in arg_spec.args: kwonlyargs.append('training') kwonlydefaults['training'] = default_training_value else: index = arg_spec.args.index('training') training_default_index = len(arg_spec.args) - index if (arg_spec.defaults and len(arg_spec.defaults) >= training_default_index and defaults[-training_default_index] is None): defaults[-training_default_index] = default_training_value decorator_argspec = tf_inspect.FullArgSpec( args=arg_spec.args, varargs=arg_spec.varargs, varkw=arg_spec.varkw, defaults=defaults, kwonlyargs=kwonlyargs, kwonlydefaults=kwonlydefaults, annotations=arg_spec.annotations) return wrap_with_training_arg, decorator_argspec def get_training_arg_index(call_fn): """Returns the index of 'training' in the layer call function arguments. Args: call_fn: Call function. Returns: - n: index of 'training' in the call function arguments. - -1: if 'training' is not found in the arguments, but layer.call accepts variable keyword arguments - None: if layer doesn't expect a training argument. 
""" arg_list = tf_inspect.getfullargspec(call_fn).args if tf_inspect.ismethod(call_fn): arg_list = arg_list[1:] if 'training' in arg_list: return arg_list.index('training') else: return -1 def set_training_arg(training, index, args, kwargs): if index is None: pass elif index >= 0 and len(args) > index: args[index] = training else: kwargs['training'] = training return args, kwargs def get_training_arg(index, args, kwargs): if index is None: return None elif index >= 0 and len(args) > index: return args[index] else: return kwargs.get('training', None) def remove_training_arg(index, args, kwargs): if index is None: pass elif index >= 0 and len(args) > index: args.pop(index) else: kwargs.pop('training', None) class SaveOptionsContext(threading.local): def __init__(self): super(SaveOptionsContext, self).__init__() self.save_traces = True _save_options_context = SaveOptionsContext() @tf_contextlib.contextmanager def keras_option_scope(save_traces): previous_value = _save_options_context.save_traces try: _save_options_context.save_traces = save_traces yield finally: _save_options_context.save_traces = previous_value def should_save_traces(): return _save_options_context.save_traces