EVOLUTION-MANAGER
Edit File: loss_scale.py
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Contains keras-specific LossScale functionality.

These functions cannot be in the non-keras loss_scale.py file since they depend
on keras, and files outside of keras should not depend on files inside keras.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import six

from tensorflow.python.keras.utils import generic_utils
from tensorflow.python.training.experimental import loss_scale as loss_scale_module


def serialize(loss_scale):
  """Serialize a loss scale object.

  Args:
    loss_scale: The object to serialize; it is passed unchanged to
      `generic_utils.serialize_keras_object`.

  Returns:
    Whatever `generic_utils.serialize_keras_object` returns for `loss_scale`.
  """
  return generic_utils.serialize_keras_object(loss_scale)


def deserialize(config, custom_objects=None):
  """Rebuild a loss scale object from `config`.

  Args:
    config: The serialized form, forwarded to
      `generic_utils.deserialize_keras_object`.
    custom_objects: Optional extra objects forwarded to
      `generic_utils.deserialize_keras_object` alongside the built-in
      `FixedLossScale` and `DynamicLossScale` entries.

  Returns:
    Whatever `generic_utils.deserialize_keras_object` returns for `config`.
  """
  # Names of the built-in loss scale classes offered to the deserializer.
  loss_scale_module_objects = {
      'FixedLossScale': loss_scale_module.FixedLossScale,
      'DynamicLossScale': loss_scale_module.DynamicLossScale,
  }

  return generic_utils.deserialize_keras_object(
      config,
      module_objects=loss_scale_module_objects,
      custom_objects=custom_objects,
      printable_module_name='loss scale'
  )


def get(identifier):
  """Get a loss scale object.

  Args:
    identifier: One of the following:
      * a dict, which is handed to `deserialize`;
      * an int or float, which is wrapped in a `FixedLossScale`;
      * the string 'dynamic', which yields a new `DynamicLossScale`;
      * a `LossScale` instance, which is returned unchanged;
      * None, which is returned unchanged.

  Returns:
    The loss scale object described by `identifier`, or None if `identifier`
    is None.

  Raises:
    ValueError: If `identifier` matches none of the cases above.
  """
  if isinstance(identifier, dict):
    return deserialize(identifier)

  if isinstance(identifier, six.integer_types + (float,)):
    return loss_scale_module.FixedLossScale(identifier)
  if identifier == 'dynamic':
    return loss_scale_module.DynamicLossScale()
  if isinstance(identifier, loss_scale_module.LossScale):
    return identifier
  elif identifier is None:
    return None
  else:
    # Wrap the operand in a 1-tuple: a bare tuple identifier would otherwise be
    # unpacked as the format arguments and raise TypeError (or be printed as
    # its sole element) instead of producing the intended ValueError.
    raise ValueError('Could not interpret loss scale identifier: %s' %
                     (identifier,))