EVOLUTION-MANAGER
Edit File: extenders.py
# Copyright 2018 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== """Extenders of tf.estimator.Estimator.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function from tensorflow.python.util import function_utils from tensorflow.python.util.tf_export import estimator_export from tensorflow_estimator.python.estimator import estimator as estimator_lib from tensorflow_estimator.python.estimator.mode_keys import ModeKeys _VALID_METRIC_FN_ARGS = set(['features', 'labels', 'predictions', 'config']) @estimator_export('estimator.add_metrics') def add_metrics(estimator, metric_fn): """Creates a new `tf.estimator.Estimator` which has given metrics. Example: ```python def my_auc(labels, predictions): auc_metric = tf.keras.metrics.AUC(name="my_auc") auc_metric.update_state(y_true=labels, y_pred=predictions['logistic']) return {'auc': auc_metric} estimator = tf.estimator.DNNClassifier(...) estimator = tf.estimator.add_metrics(estimator, my_auc) estimator.train(...) estimator.evaluate(...) ``` Example usage of custom metric which uses features: ```python def my_auc(labels, predictions, features): auc_metric = tf.keras.metrics.AUC(name="my_auc") auc_metric.update_state(y_true=labels, y_pred=predictions['logistic'], sample_weight=features['weight']) return {'auc': auc_metric} estimator = tf.estimator.DNNClassifier(...) estimator = tf.estimator.add_metrics(estimator, my_auc) estimator.train(...) estimator.evaluate(...) ``` Args: estimator: A `tf.estimator.Estimator` object. metric_fn: A function which should obey the following signature: - Args: can only have following four arguments in any order: * predictions: Predictions `Tensor` or dict of `Tensor` created by given `estimator`. * features: Input `dict` of `Tensor` objects created by `input_fn` which is given to `estimator.evaluate` as an argument. * labels: Labels `Tensor` or dict of `Tensor` created by `input_fn` which is given to `estimator.evaluate` as an argument. * config: config attribute of the `estimator`. - Returns: Dict of metric results keyed by name. Final metrics are a union of this and `estimator's` existing metrics. If there is a name conflict between this and `estimator`s existing metrics, this will override the existing one. The values of the dict are the results of calling a metric function, namely a `(metric_tensor, update_op)` tuple. Returns: A new `tf.estimator.Estimator` which has a union of original metrics with given ones. """ _verify_metric_fn_args(metric_fn) def new_model_fn(features, labels, mode, config): spec = estimator.model_fn(features, labels, mode, config) if mode != ModeKeys.EVAL: return spec new_metrics = _call_metric_fn(metric_fn, features, labels, spec.predictions, config) all_metrics = spec.eval_metric_ops or {} all_metrics.update(new_metrics) return spec._replace(eval_metric_ops=all_metrics) return estimator_lib.Estimator( model_fn=new_model_fn, model_dir=estimator.model_dir, config=estimator.config, # pylint: disable=protected-access warm_start_from=estimator._warm_start_settings) # pylint: enable=protected-access def _verify_metric_fn_args(metric_fn): args = set(function_utils.fn_args(metric_fn)) invalid_args = list(args - _VALID_METRIC_FN_ARGS) if invalid_args: raise ValueError('metric_fn (%s) has following not expected args: %s' % (metric_fn, invalid_args)) def _call_metric_fn(metric_fn, features, labels, predictions, config): """Calls metric fn with proper arguments.""" metric_fn_args = function_utils.fn_args(metric_fn) kwargs = {} if 'features' in metric_fn_args: kwargs['features'] = features if 'labels' in metric_fn_args: kwargs['labels'] = labels if 'predictions' in metric_fn_args: kwargs['predictions'] = predictions if 'config' in metric_fn_args: kwargs['config'] = config return metric_fn(**kwargs)