ronnie
2022-10-14 1504bb53e29d3d46222c0b3ea994fc494b48e153
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for integration tests."""
 
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
 
import functools
 
import tensorflow.compat.v2 as tf
 
from tensorflow.python.framework import smart_cond
from tensorflow.python.util import tf_inspect
 
 
# TODO(vbardiovsky): We should just reuse Keras's Lambda layer, when that
# enables to get trainable variables.
class CustomLayer(tf.keras.layers.Layer):
  """Wraps callable object as a `Layer` object.
 
  Args:
    func: The callable object to wrap. Layer inputs are passed as the first
      positional argument. If `func` accepts a `training` argument, a Python
      boolean is passed for it.
      If present, the following attributes of `func` have a special meaning:
        * variables: a list of all tf.Variable objects that `func` depends on.
        * trainable_variables: those elements of `variables` that are reported
          as trainable variables of this Keras Layer.
        * regularization_losses: a list of callables to be added as losses
          of this Keras layer. Each one must accept zero arguments and return
          a scalare tensor.
    trainable: Boolean controlling whether the trainable variables of `func`
      are reported as trainable variables of this layer.
    arguments: optionally, a dict with additional keyword arguments passed
      to `func`.
    **kwargs: 'output_shape': A tuple with the (possibly partial) output
      shape of the callable *without* leading batch size. Other arguments
      are pass into the Layer constructor.
  """
 
  def __init__(self, func, trainable=False, arguments=None, **kwargs):
    # Set self._{non,}_trainable_weights before calling Layer.__init__.
    if hasattr(func, 'trainable_variables'):
      self._trainable_weights = [v for v in func.trainable_variables]
      trainable_variables_set = set(func.trainable_variables)
    else:
      self._trainable_weights = []
      trainable_variables_set = set()
    if hasattr(func, 'variables'):
      self._non_trainable_weights = [v for v in func.variables
                                     if v not in trainable_variables_set]
    else:
      self._non_trainable_weights = []  # TODO(arnoegw): Infer from `func`.
 
    # TODO(b/124219898): We should be able to get the embedding dimension from
    # the restored model.
    if 'output_shape' in kwargs:
      self._output_shape = tuple(kwargs.pop('output_shape'))
 
    super(CustomLayer, self).__init__(trainable=trainable, **kwargs)
    # Prepare to call `func`.
    self._func = func
    self._func_fullargspec = tf_inspect.getfullargspec(func.__call__)
    self._func_wants_training = (
        'training' in self._func_fullargspec.args or
        'training' in self._func_fullargspec.kwonlyargs)
    self._arguments = arguments or {}
    # Forward the callable's regularization losses (if any).
    if hasattr(func, 'regularization_losses'):
      for l in func.regularization_losses:
        if not callable(l):
          raise ValueError(
              'CustomLayer(func) expects func.regularization_losses to be an '
              'iterable of callables, each returning a scalar loss term.')
        self.add_loss(l)  # Supports callables.
 
  def call(self, x, training=None):
    # We basically want to call this...
    f = functools.partial(self._func, x, **self._arguments)
    # ...but we may also have to pass a Python boolean for `training`.
    if not self._func_wants_training:
      result = f()
    else:
      if training is None:
        training = tf.keras.backend.learning_phase()  # Could be a tensor.
      result = smart_cond.smart_cond(training,
                                     lambda: f(training=True),
                                     lambda: f(training=False))
    # TODO(b/124219898): Polymorphic function should return shaped tensor.
    if hasattr(self, '_output_shape'):
      result.set_shape((x.shape[0],) + self._output_shape)
    return result