# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tools for deserializing `Function`s."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import re

from absl import logging

from tensorflow.core.framework import function_pb2
from tensorflow.core.protobuf import saved_object_graph_pb2
from tensorflow.python.eager import def_function
from tensorflow.python.eager import function as function_lib
from tensorflow.python.framework import func_graph as func_graph_lib
from tensorflow.python.framework import function_def_to_graph as function_def_lib
from tensorflow.python.framework import op_def_registry
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import type_spec
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.saved_model import nested_structure_coder
from tensorflow.python.util import compat
from tensorflow.python.util import nest
from tensorflow.python.util import tf_decorator
from tensorflow.python.util import tf_inspect


def _is_tensor(t):
  return isinstance(t,
                    (ops.Tensor, resource_variable_ops.BaseResourceVariable))


# TODO(edloper): Update this to just use ConcreteFunction.__call__ with the
# structured signature.
def _call_concrete_function(function, inputs):
  """Calls a restored Function with structured inputs.

  This differs from `function.__call__` in that inputs and outputs are
  structured and that it casts inputs to tensors if needed.

  Note: this does not check that non-tensor inputs match. That should be
  done before via `_concrete_function_callable_with`.

  Args:
    function: ConcreteFunction to call.
    inputs: Structured inputs compatible with
      `function.graph.structured_input_signature`.

  Returns:
    The structured function output.
  """
  expected_structure = function.graph.structured_input_signature
  flatten_inputs = nest.flatten_up_to(
      expected_structure, inputs, expand_composites=True)
  flatten_expected = nest.flatten(expected_structure, expand_composites=True)
  tensor_inputs = []
  for arg, expected in zip(flatten_inputs, flatten_expected):
    if isinstance(expected, tensor_spec.TensorSpec):
      tensor_inputs.append(
          ops.convert_to_tensor(arg, dtype_hint=expected.dtype))
  result = function._call_flat(tensor_inputs, function._captured_inputs)  # pylint: disable=protected-access
  if isinstance(result, ops.Operation):
    return None
  return result

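
# For orientation, a minimal sketch of the structures `_call_concrete_function`
# operates on (the concrete function and values below are hypothetical, not
# part of this module):
#
#   # Suppose a function was traced as f(x) with x a float32 scalar. Then
#   # function.graph.structured_input_signature is
#   #   ((TensorSpec(shape=(), dtype=tf.float32, name="x"),), {})
#   # and calling with structured inputs ((2.0,), {}) flattens to [2.0], which
#   # is converted to a float32 tensor before `_call_flat` runs the graph:
#   _call_concrete_function(function, inputs=((2.0,), {}))
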
def _try_convert_to_tensor_spec(arg, dtype_hint):
  """Returns None, or a TensorSpec obtained if `arg` converts to a tensor."""
  try:
    # Note: try conversion in a FuncGraph to avoid polluting current context.
    with func_graph_lib.FuncGraph(name="guess_conversion").as_default():
      result = ops.convert_to_tensor(arg, dtype_hint=dtype_hint)
      return tensor_spec.TensorSpec(shape=result.shape, dtype=result.dtype)
  except (TypeError, ValueError):
    return None


def _concrete_function_callable_with(function, inputs, allow_conversion):
  """Returns whether concrete `function` can be called with `inputs`."""
  expected_structure = function.graph.structured_input_signature
  try:
    flatten_inputs = nest.flatten_up_to(expected_structure, inputs)
  except (TypeError, ValueError):
    return False

  for arg, expected in zip(flatten_inputs, nest.flatten(expected_structure)):
    if isinstance(expected, tensor_spec.TensorSpec):
      if allow_conversion:
        arg = _try_convert_to_tensor_spec(arg, dtype_hint=expected.dtype)
      if not _is_tensor(arg) and not isinstance(arg, tensor_spec.TensorSpec):
        return False
      if arg.dtype != expected.dtype:
        return False
      if not expected.shape.is_compatible_with(arg.shape):
        return False
    elif isinstance(expected, type_spec.TypeSpec):
      if not expected.is_compatible_with(arg):
        return False
    elif _is_tensor(arg):
      if id(arg) != id(expected):
        return False
    else:
      if arg != expected:
        return False
  return True


def _deserialize_function_spec_as_nonmethod(function_spec_proto, coder):
  """Deserialize a FunctionSpec object from its proto representation."""
  typeless_fullargspec = coder.decode_proto(function_spec_proto.fullargspec)

  # Convert a method function into a non-method.
  if function_spec_proto.is_method:
    if not typeless_fullargspec.args:
      raise NotImplementedError(
          "Missing support to deserialize a method function without a named "
          "'self' argument.")
    args = typeless_fullargspec.args[1:]
  else:
    args = typeless_fullargspec.args

  fullargspec = tf_inspect.FullArgSpec(
      args=args,
      varargs=typeless_fullargspec.varargs,
      varkw=typeless_fullargspec.varkw,
      defaults=typeless_fullargspec.defaults,
      kwonlyargs=typeless_fullargspec.kwonlyargs,
      kwonlydefaults=typeless_fullargspec.kwonlydefaults,
      annotations=typeless_fullargspec.annotations)
  input_signature = coder.decode_proto(function_spec_proto.input_signature)

  # See `tf.function` and the ExperimentalCompile proto for details.
  experimental_compile = {
      saved_object_graph_pb2.FunctionSpec.ExperimentalCompile.DEFAULT: None,
      saved_object_graph_pb2.FunctionSpec.ExperimentalCompile.ON: True,
      saved_object_graph_pb2.FunctionSpec.ExperimentalCompile.OFF: False,
  }.get(function_spec_proto.experimental_compile)

  return function_lib.FunctionSpec(fullargspec=fullargspec,
                                   is_method=False,
                                   input_signature=input_signature,
                                   experimental_compile=experimental_compile)

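
# A worked example of the method-to-non-method conversion above (argument
# names are hypothetical): a method traced as `def f(self, x)` is serialized
# with is_method=True and fullargspec.args == ["self", "x"]; the deserialized
# FunctionSpec drops "self" and reports args == ["x"] with is_method=False,
# since a restored function always uses the same captured tensors and does not
# need the object it was bound to.
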
# TODO(allenl): The fact that we can't derive ConcreteFunction calling
# conventions from the serialized input spec right now is unfortunate. Merging
# these would be good, maybe by adding TensorSpec names to cache keys so
# renamed keyword arguments would yield different ConcreteFunctions.
def setup_bare_concrete_function(saved_bare_concrete_function,
                                 concrete_functions):
  """Makes a restored bare concrete function callable."""
  concrete_function = concrete_functions[
      saved_bare_concrete_function.concrete_function_name]
  # pylint: disable=protected-access
  concrete_function._arg_keywords = (
      saved_bare_concrete_function.argument_keywords)
  concrete_function._num_positional_args = (
      saved_bare_concrete_function.allowed_positional_arguments)
  if saved_bare_concrete_function.HasField("function_spec"):
    coder = nested_structure_coder.StructureCoder()
    function_spec = _deserialize_function_spec_as_nonmethod(
        saved_bare_concrete_function.function_spec, coder)
    concrete_function._set_function_spec(function_spec)
  # pylint: enable=protected-access
  concrete_function.add_to_graph()
  return concrete_function


class RestoredFunction(def_function.Function):
  """Wrapper class for a function that has been restored from saved state.

  See `def_function.Function`.
  """

  def __init__(self, python_function, name, function_spec, concrete_functions):
    # TODO(mdan): We may enable autograph once exceptions are supported.
    super(RestoredFunction, self).__init__(
        python_function,
        name,
        autograph=False,
        experimental_compile=function_spec.experimental_compile)
    self.concrete_functions = concrete_functions
    self._function_spec = function_spec

  def _list_all_concrete_functions_for_serialization(self):
    return self.concrete_functions

  def _defun_with_scope(self, scope):
    func = super(RestoredFunction, self)._defun_with_scope(scope)
    func._function_spec = self._function_spec  # pylint: disable=protected-access
    return func

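
# Rough usage sketch (the path and attribute name are hypothetical;
# `tf.saved_model.load` is the public entry point that routes through this
# module):
#
#   loaded = tf.saved_model.load("/tmp/my_model")
#   # `loaded.f` is a RestoredFunction built by `recreate_function` below;
#   # calling it dispatches to one of the restored ConcreteFunctions.
#   loaded.f(tf.constant(1.0))
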
def recreate_function(saved_function, concrete_functions):
  """Creates a `Function` from a `SavedFunction`.

  Args:
    saved_function: `SavedFunction` proto.
    concrete_functions: map from function name to `ConcreteFunction`. As a side
      effect of this function, the `FunctionSpec` from `saved_function` is
      added to each `ConcreteFunction` in this map.

  Returns:
    A `Function`.
  """
  # TODO(andresp): Construct a `Function` with the cache populated
  # instead of creating a new `Function` backed by a Python layer to
  # glue things together. The current approach nests functions deeper for each
  # serialization cycle.

  coder = nested_structure_coder.StructureCoder()

  # Note: handling method functions is tricky since make_decorator does not
  # allow control of "ismethod". Additionally, since restored functions do
  # not behave as methods, i.e. they always use the same captured tensors
  # independent of the object they are bound to, there is little value in
  # propagating that correctly.
  #
  # Ideally this conversion should happen at serialization time. But since
  # there are SavedModels which have "ismethod" populated and have an extra
  # argument that they expect to be ignored, we do it at deserialization.
  function_spec = _deserialize_function_spec_as_nonmethod(
      saved_function.function_spec, coder)

  def restored_function_body(*args, **kwargs):
    """Calls a restored function or raises an error if no matching function."""
    if not saved_function.concrete_functions:
      raise ValueError("Found zero restored functions for caller function.")
    # This is the format of function.graph.structured_input_signature. At this
    # point, the args and kwargs have already been canonicalized.
    inputs = (args, kwargs)

    # First try to find a concrete function that can be called without input
    # conversions. This allows one to pick a more specific trace in case there
    # was also a more expensive one that supported tensors.
    for allow_conversion in [False, True]:
      for function_name in saved_function.concrete_functions:
        function = concrete_functions[function_name]
        if _concrete_function_callable_with(function, inputs, allow_conversion):
          return _call_concrete_function(function, inputs)

    signature_descriptions = []

    def _pretty_format_positional(positional):
      return "Positional arguments ({} total):\n * {}".format(
          len(positional), "\n * ".join(str(a) for a in positional))

    for index, function_name in enumerate(saved_function.concrete_functions):
      concrete_function = concrete_functions[function_name]
      positional, keyword = concrete_function.structured_input_signature
      signature_descriptions.append(
          "Option {}:\n {}\n Keyword arguments: {}"
          .format(index + 1, _pretty_format_positional(positional), keyword))
    raise ValueError(
        "Could not find matching function to call loaded from the SavedModel. "
        "Got:\n {}\n Keyword arguments: {}\n\nExpected "
        "these arguments to match one of the following {} option(s):\n\n{}"
        .format(_pretty_format_positional(args), kwargs,
                len(saved_function.concrete_functions),
                "\n\n".join(signature_descriptions)))

  concrete_function_objects = []
  for concrete_function_name in saved_function.concrete_functions:
    concrete_function_objects.append(
        concrete_functions[concrete_function_name])

  for cf in concrete_function_objects:
    cf._set_function_spec(function_spec)  # pylint: disable=protected-access

  restored_function = RestoredFunction(
      restored_function_body,
      restored_function_body.__name__,
      function_spec,
      concrete_function_objects)

  return tf_decorator.make_decorator(
      restored_function_body,
      restored_function,
      decorator_argspec=function_spec.fullargspec)

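
# Note on the dispatch order in `restored_function_body`: the conversion-free
# pass runs first so that a trace recorded for a non-tensor Python value is
# preferred over a broader trace that would accept the value after conversion
# to a tensor; only when no trace matches exactly is the search retried with
# `allow_conversion=True`.
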
def load_function_def_library(library, load_shared_name_suffix=None):
  """Load a set of functions as concrete functions without captured inputs.

  Function names are manipulated during load such that they do not overlap
  with previously created ones.

  Args:
    library: FunctionDefLibrary proto message.
    load_shared_name_suffix: If specified, used to uniquify shared names.
      Otherwise, a unique name is generated.

  Returns:
    Map of original function names in the library to instances of
    `ConcreteFunction` without captured inputs.

  Raises:
    ValueError: if function dependencies have a cycle.
  """
  library_function_names = set(
      fdef.signature.name for fdef in library.function)
  functions = {}
  renamed_functions = {}

  # Our graph building code currently requires functions to be registered with
  # some tf.Graph in order to import functions using the
  # op-name-is-function-name calling convention. To avoid leaking memory into
  # the global default graph when executing eagerly, we create a temporary
  # Graph.
  #
  # TODO(allenl): Make this Graph creation unnecessary when executing eagerly
  # by fixing function_def_to_graph_def.
  if ops.executing_eagerly_outside_functions():
    graph = ops.Graph()
  else:
    graph = ops.get_default_graph()

  if load_shared_name_suffix is None:
    load_shared_name_suffix = "_load_{}".format(ops.uid())
  for fdef in _sort_function_defs(library, library_function_names):
    copy = _fix_fdef(fdef, functions, load_shared_name_suffix)

    # There is no need to copy all functions into the function def graph. It
    # leads to an O(n^2) increase in memory when importing functions, and the
    # extra function definitions are a no-op since they were already imported
    # as functions before and are passed in explicitly (due to the topological
    # sort import).
    with graph.as_default():
      func_graph = function_def_lib.function_def_to_graph(copy)
    _restore_gradient_functions(func_graph, renamed_functions)

    for dep in _list_function_deps(fdef, library_function_names):
      functions[dep].add_to_graph(func_graph)

    # We do not initialize the new ConcreteFunction's function_spec and/or
    # arg_keywords here (which are used to parse the structured and flat
    # signatures, respectively). ConcreteFunctions that are part of a saved
    # function are set up later by recreate_function(); bare ConcreteFunctions
    # are set up by setup_bare_concrete_function().
    func = function_lib.ConcreteFunction(func_graph)
    func.add_to_graph(graph)

    functions[fdef.signature.name] = func
    renamed_functions[func.name] = func
    if any(op.type == "TRTEngineOp" for op in func_graph.get_operations()):
      # TODO(b/150708051): Remove this hack once TensorRT SavedModel
      # integration is fixed. Currently it's leaking memory to maintain bug
      # compatibility with previous behavior.
      func.add_to_graph(ops.get_default_graph())

  return functions


def _restore_gradient_functions(func_graph, renamed_functions):
  """Populate function op's _gradient_function with default gradient."""
  for op in func_graph.get_operations():
    # TODO(andresp): This code assumes that the gradient registered for this
    # function call is the default gradient for the function and not a custom
    # one.
    if op.type in ["StatefulPartitionedCall", "PartitionedCall"]:
      function = renamed_functions[compat.as_bytes(
          op.node_def.attr["f"].func.name)]
      op._gradient_function = function._get_gradient_function()  # pylint: disable=protected-access


def _sort_function_defs(library, library_function_names):
  """Return a topological sort of FunctionDefs in a library."""
  edges = collections.defaultdict(list)
  in_count = collections.defaultdict(lambda: 0)

  for fdef in library.function:
    for dep in _list_function_deps(fdef, library_function_names):
      edges[dep].append(fdef.signature.name)
      in_count[fdef.signature.name] += 1

  ready = [
      fdef.signature.name
      for fdef in library.function
      if in_count[fdef.signature.name] == 0
  ]
  output = []
  while ready:
    node = ready.pop()
    output.append(node)
    for dest in edges[node]:
      in_count[dest] -= 1
      if not in_count[dest]:
        ready.append(dest)
  if len(output) != len(library.function):
    failed_to_resolve = sorted(set(in_count.keys()) - set(output))
    raise ValueError(
        "There is a cyclic dependency between functions. "
        "Could not resolve %r." % (failed_to_resolve,))

  reverse = {fdef.signature.name: fdef for fdef in library.function}
  return [reverse[x] for x in output]

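
# A tiny worked example of the topological sort above (function names are
# hypothetical): if `caller` invokes `helper` and `helper` invokes `leaf`, the
# edges are leaf -> helper -> caller and the returned order is
# [leaf, helper, caller], so every function is registered before the functions
# that call it are imported. A library where `a` calls `b` and `b` calls `a`
# never drains the ready queue and raises the cyclic-dependency ValueError.
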
def fix_node_def(node_def, functions, shared_name_suffix, debug_name):
  """Replace function calls and shared names in `node_def`."""
  if ("_gradient_op_type" in node_def.attr and
      node_def.op not in ["StatefulPartitionedCall", "PartitionedCall"]):
    logging.warning(
        "Importing a function (%s) with ops with custom gradients. Will "
        "likely fail if a gradient is requested.", debug_name)
  if node_def.op in functions:
    node_def.op = functions[node_def.op].name
  for _, attr_value in node_def.attr.items():
    if attr_value.WhichOneof("value") == "func":
      attr_value.func.name = functions[attr_value.func.name].name
    elif attr_value.WhichOneof("value") == "list":
      for fn in attr_value.list.func:
        fn.name = functions[fn.name].name

  # Fix old table creation bug.
  if node_def.op == "HashTableV2":
    if ("use_node_name_sharing" not in node_def.attr or
        not node_def.attr["use_node_name_sharing"].b):
      node_def.attr["use_node_name_sharing"].b = True
      # We are turning on node name sharing, so have to make sure we don't
      # accidentally share a table resource.
      shared_name_suffix += "_{}".format(ops.uid())

  # TODO(b/124205571): Avoid accidental sharing and destruction of restored
  # resources. For now uniquify "shared_name" when loading functions to avoid
  # sharing.
  # TODO: Add regression test for b/150826922.
  op_def = op_def_registry.get(node_def.op)
  if op_def:
    attr = next((a for a in op_def.attr if a.name == "shared_name"), None)
    if attr:
      shared_name = None
      if "shared_name" in node_def.attr and node_def.attr["shared_name"].s:
        shared_name = node_def.attr["shared_name"].s
      elif attr.default_value.s:
        shared_name = compat.as_bytes(attr.default_value.s)
      if not shared_name:
        shared_name = compat.as_bytes(node_def.name)

      node_def.attr["shared_name"].s = (
          shared_name + compat.as_bytes(shared_name_suffix))


def _fix_fdef(orig_fdef, functions, shared_name_suffix):
  """Fixes a FunctionDef proto to be loaded in current context.

  In particular, when loading a function library into an eager context, one
  must rename the functions to avoid conflicts with existing functions.

  Args:
    orig_fdef: FunctionDef proto to fix. It is not modified.
    functions: map from function name to a ConcreteFunction instance.
    shared_name_suffix: A unique string for this load which helps to avoid
      `shared_name` collisions across loads. Two functions from the same load
      using the same `shared_name` still need to share, but functions from
      different loads with the same `shared_name` should not.

  Returns:
    A fixed copy of the original FunctionDef.
  """
  fdef = function_pb2.FunctionDef()
  fdef.CopyFrom(orig_fdef)
  for node_def in fdef.node_def:
    fix_node_def(node_def, functions, shared_name_suffix, fdef.signature.name)

  fdef.signature.name = _clean_function_name(fdef.signature.name)
  return fdef


def _list_function_deps(fdef, library_function_names):
  """Find functions referenced in `fdef`."""
  # TODO(andresp): Recurse into list attributes and into NameAttrList attrs
  # both when listing deps and when fixing them. `function_def_to_graph` also
  # requires fixes.
  deps = set()
  for node_def in fdef.node_def:
    if node_def.op in library_function_names:
      deps.add(node_def.op)
    else:
      for _, attr_value in node_def.attr.items():
        if attr_value.WhichOneof("value") == "func":
          deps.add(attr_value.func.name)
        elif attr_value.WhichOneof("value") == "list":
          for fn in attr_value.list.func:
            deps.add(fn.name)
  return deps


_FUNCTION_WRAPPER_NAME_REGEX = r"^%s(.*)_\d+$" % (
    function_lib._INFERENCE_PREFIX)  # pylint:disable=protected-access


def _clean_function_name(name):
  """Vanity function to keep the function names comprehensible."""
  # Note: each time a function is wrapped into `function_lib.ConcreteFunction`
  # its name becomes "__inference_<orig>_xyz".
  match = re.search(_FUNCTION_WRAPPER_NAME_REGEX, name)
  if match:
    return match.group(1)
  else:
    return name

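
# Example of the cleanup above (names are hypothetical): a function saved as
# "__inference_my_fn_42" is restored under the name "my_fn", while a name that
# does not match the "__inference_<orig>_<uid>" pattern is returned unchanged.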