# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for ops which manipulate lists of tensors via bridge."""

# pylint: disable=g-bad-name
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os

import numpy as np

from tensorflow.compiler.tests import xla_test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import list_ops
from tensorflow.python.platform import test

class ListOpsTest(xla_test.XLATestCase):
  """Tests TensorList ops (push/pop, get/set, stack) through the XLA bridge."""

  def testElementShape(self):
    """element_shape may contain a placeholder dim; both shape dtypes work."""
    with self.cached_session() as sess, self.test_scope():
      dim = array_ops.placeholder(dtypes.int32)
      l = list_ops.empty_tensor_list(
          element_shape=(dim, 15),
          element_dtype=dtypes.float32,
          max_num_elements=20)
      e32 = list_ops.tensor_list_element_shape(l, shape_type=dtypes.int32)
      e64 = list_ops.tensor_list_element_shape(l, shape_type=dtypes.int64)
      self.assertAllEqual(sess.run(e32, {dim: 10}), (10, 15))
      self.assertAllEqual(sess.run(e64, {dim: 7}), (7, 15))

  def testPushPop(self):
    """Elements pop back in LIFO order."""
    with self.cached_session() as sess, self.test_scope():
      l = list_ops.empty_tensor_list(
          element_shape=(7, 15),
          element_dtype=dtypes.float32,
          max_num_elements=10)
      l = list_ops.tensor_list_push_back(
          l, constant_op.constant(1.0, shape=(7, 15)))
      l = list_ops.tensor_list_push_back(
          l, constant_op.constant(2.0, shape=(7, 15)))
      l, e2 = list_ops.tensor_list_pop_back(l, element_dtype=dtypes.float32)
      _, e1 = list_ops.tensor_list_pop_back(l, element_dtype=dtypes.float32)
      self.assertAllEqual(sess.run(e2), 2.0 * np.ones((7, 15)))
      self.assertAllEqual(sess.run(e1), 1.0 * np.ones((7, 15)))

  def testDoNotConstantFoldVariants(self):
    """Constant folding must not fold DT_VARIANT-producing nodes for XLA."""
    with self.cached_session() as sess, self.test_scope():
      val = array_ops.placeholder(dtype=dtypes.float32)
      l = list_ops.empty_tensor_list(
          element_shape=(7, 15),
          element_dtype=dtypes.float32,
          max_num_elements=10)
      # Note: Pushing a Placeholder will force the constant folding code
      # to build a Const node with a DT_VARIANT output. This tests that XLA
      # passes a cf_consider_fn which prevent folding such nodes.
      l = list_ops.tensor_list_push_back(
          l, array_ops.fill(value=val, dims=(7, 15)))
      l = list_ops.tensor_list_push_back(
          l, constant_op.constant(2.0, shape=(7, 15)))
      l, e2 = list_ops.tensor_list_pop_back(l, element_dtype=dtypes.float32)
      _, e1 = list_ops.tensor_list_pop_back(l, element_dtype=dtypes.float32)
      self.assertAllEqual(sess.run(e2, {val: 1.0}), 2.0 * np.ones((7, 15)))
      self.assertAllEqual(sess.run(e1, {val: 1.0}), 1.0 * np.ones((7, 15)))

  def testPushPopSeparateLists(self):
    """Pushing onto a list yields value-semantics copies, not aliases."""
    with self.cached_session() as sess, self.test_scope():
      l = list_ops.empty_tensor_list(
          element_shape=[],
          element_dtype=dtypes.float32,
          max_num_elements=20)
      l = list_ops.tensor_list_push_back(l, constant_op.constant(1.0))
      l2 = list_ops.tensor_list_push_back(l, constant_op.constant(2.0))
      l3 = list_ops.tensor_list_push_back(l, constant_op.constant(3.0))
      _, e11 = list_ops.tensor_list_pop_back(l, element_dtype=dtypes.float32)
      l2, e21 = list_ops.tensor_list_pop_back(l2, element_dtype=dtypes.float32)
      l2, e22 = list_ops.tensor_list_pop_back(l2, element_dtype=dtypes.float32)
      l3, e31 = list_ops.tensor_list_pop_back(l3, element_dtype=dtypes.float32)
      l3, e32 = list_ops.tensor_list_pop_back(l3, element_dtype=dtypes.float32)
      result = sess.run([e11, [e21, e22], [e31, e32]])
      self.assertEqual(result, [1.0, [2.0, 1.0], [3.0, 1.0]])

  def testEmptyTensorListNoMax(self):
    """Omitting max_num_elements raises at run time under XLA."""
    with self.cached_session() as sess, self.test_scope():
      l = list_ops.empty_tensor_list(
          element_shape=(7, 15), element_dtype=dtypes.float32)
      l = list_ops.tensor_list_push_back(
          l, constant_op.constant(1.0, shape=(7, 15)))
      _, e = list_ops.tensor_list_pop_back(l, element_dtype=dtypes.float32)
      with self.assertRaisesRegexp(errors.InvalidArgumentError,
                                   "Set the max number of elements"):
        self.assertAllEqual(sess.run(e), 1.0 * np.ones((7, 15)))

  def testEmptyTensorListMax(self):
    """Push/pop works when max_num_elements is provided."""
    with self.cached_session() as sess, self.test_scope():
      l = list_ops.empty_tensor_list(
          element_shape=(10, 15), element_dtype=dtypes.float32,
          max_num_elements=2)
      l = list_ops.tensor_list_push_back(
          l, array_ops.fill(value=3.0, dims=(10, 15)))
      _, e = list_ops.tensor_list_pop_back(l, element_dtype=dtypes.float32)
      self.assertAllEqual(sess.run(e), 3.0 * np.ones((10, 15)))

  def testListFromTensor(self):
    """A list built from a tensor supports get_item, pop_back, and length."""
    with self.cached_session(), self.test_scope():
      t = constant_op.constant([1.0, 2.0])
      l = list_ops.tensor_list_from_tensor(t, element_shape=[])
      e = list_ops.tensor_list_get_item(l, 0, element_dtype=dtypes.float32)
      self.assertAllEqual(e, 1.0)
      l, e0 = list_ops.tensor_list_pop_back(l, element_dtype=dtypes.float32)
      self.assertAllEqual(e0, 2.0)
      l, e1 = list_ops.tensor_list_pop_back(l, element_dtype=dtypes.float32)
      self.assertAllEqual(e1, 1.0)
      self.assertAllEqual(list_ops.tensor_list_length(l), 0)

  def testGetSet(self):
    """set_item overwrites an element and stack reflects the update."""
    with self.cached_session(), self.test_scope():
      t = constant_op.constant([1.0, 2.0])
      l = list_ops.tensor_list_from_tensor(t, element_shape=[])
      e0 = list_ops.tensor_list_get_item(l, 0, element_dtype=dtypes.float32)
      self.assertAllEqual(e0, 1.0)
      l = list_ops.tensor_list_set_item(l, 0, 3.0)
      t = list_ops.tensor_list_stack(l, element_dtype=dtypes.float32)
      self.assertAllEqual(t, [3.0, 2.0])

  def testSetDoesNotUpdatePushIndex(self):
    """set_item at a later index must not advance the push index."""
    with self.cached_session(), self.test_scope():
      l = list_ops.empty_tensor_list(
          element_shape=[], element_dtype=dtypes.float32, max_num_elements=2)
      # SetItem should not change the push index.
      l = list_ops.tensor_list_set_item(l, 1, 3.)
      l = list_ops.tensor_list_push_back(l, 5.)
      l = list_ops.tensor_list_push_back(l, 7.)
      t = list_ops.tensor_list_stack(l, element_dtype=dtypes.float32)
      self.assertAllEqual(t, [5., 7.])

  def testGetSetReserved(self):
    """Reserved lists start zero-initialized; set_item overwrites one slot."""
    with self.cached_session(), self.test_scope():
      l = list_ops.tensor_list_reserve(
          element_dtype=dtypes.float32, element_shape=[], num_elements=2)
      e0 = list_ops.tensor_list_get_item(l, 0, element_dtype=dtypes.float32)
      self.assertAllEqual(e0, 0.0)
      l = list_ops.tensor_list_set_item(l, 0, 3.0)
      t = list_ops.tensor_list_stack(l, element_dtype=dtypes.float32)
      self.assertAllEqual(t, [3.0, 0.0])

  def testSetStackReservedUnknownElementShape(self):
    """A reserved list with unknown element_shape infers it from set_item."""
    with self.cached_session(), self.test_scope():
      l = list_ops.tensor_list_reserve(
          element_dtype=dtypes.float32, element_shape=None, num_elements=2)
      l = list_ops.tensor_list_set_item(l, 0, [3.0, 4.0])
      t = list_ops.tensor_list_stack(l, element_dtype=dtypes.float32)
      self.assertAllEqual(t, [[3.0, 4.0], [0., 0.]])

  def testPushInEmptyListWithUnknownElementShape(self):
    """First push fixes the element shape; mismatched pushes raise."""
    with self.cached_session(), self.test_scope():
      l = list_ops.empty_tensor_list(
          element_dtype=dtypes.float32, element_shape=None, max_num_elements=2)
      l = list_ops.tensor_list_push_back(l, [3.0, 4.0])
      # Pushing an element with a different shape should raise an error.
      with self.assertRaisesRegexp(errors.InternalError, "shape"):
        l = list_ops.tensor_list_push_back(l, 5.)
        self.evaluate(
            list_ops.tensor_list_stack(l, element_dtype=dtypes.float32))

  def testGetSetReservedNonScalar(self):
    """Get/set on reserved lists works with non-scalar element shapes."""
    with self.cached_session() as sess, self.test_scope():
      l = list_ops.tensor_list_reserve(
          element_dtype=dtypes.float32,
          element_shape=(7, 15),
          num_elements=2)
      l = list_ops.tensor_list_set_item(
          l, 0, constant_op.constant(1.0, shape=(7, 15)))
      e1 = list_ops.tensor_list_get_item(l, 0, element_dtype=dtypes.float32)
      e2 = list_ops.tensor_list_get_item(l, 1, element_dtype=dtypes.float32)
      self.assertAllEqual(sess.run(e1), np.ones((7, 15)))
      self.assertAllEqual(sess.run(e2), np.zeros((7, 15)))

  def testStack(self):
    """stack concatenates pushed elements; static shape stays unknown."""
    with self.cached_session(), self.test_scope():
      l = list_ops.empty_tensor_list(
          element_dtype=dtypes.float32,
          element_shape=[],
          max_num_elements=2)
      l = list_ops.tensor_list_push_back(l, constant_op.constant(1.0))
      e = list_ops.tensor_list_get_item(l, 0, element_dtype=dtypes.float32)
      self.assertAllEqual(e, 1.0)
      l = list_ops.tensor_list_push_back(l, constant_op.constant(2.0))
      t = list_ops.tensor_list_stack(l, element_dtype=dtypes.float32)
      self.assertAllEqual(t.shape.as_list(), [None])
      self.assertAllEqual(t, [1.0, 2.0])

  def testStackWithUninitializedTensors(self):
    """Stacking a fully uninitialized reserved list yields zeros."""
    with self.cached_session(), self.test_scope():
      l = list_ops.tensor_list_reserve(
          element_dtype=dtypes.float32, element_shape=[], num_elements=3)
      t = list_ops.tensor_list_stack(l, element_dtype=dtypes.float32)
      self.assertAllEqual(t, [0., 0., 0.])
if __name__ == "__main__":
  # Lower the XLA clustering threshold so even these tiny test graphs are
  # compiled through the bridge; prepend so existing flags still apply.
  os.environ['TF_XLA_FLAGS'] = ('--tf_xla_min_cluster_size=2 ' +
                                os.environ.get('TF_XLA_FLAGS', ''))
  test.main()