/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
|
|
#include "tensorflow/core/framework/partial_tensor_shape.h"
|
|
#include "tensorflow/core/framework/tensor_shape.pb.h"
|
#include "tensorflow/core/lib/core/errors.h"
|
#include "tensorflow/core/lib/core/status_test_util.h"
|
#include "tensorflow/core/platform/test.h"
|
|
namespace tensorflow {
|
namespace {
|
|
TEST(PartialTensorShapeTest, Default) {
  // A default-constructed PartialTensorShape has unknown rank, which
  // dims() reports as -1.
  const PartialTensorShape default_shape;
  EXPECT_EQ(default_shape.dims(), -1);
  EXPECT_TRUE(default_shape.unknown_rank());
}
|
|
TEST(PartialTensorShapeTest, Concatenate) {
  // Start from a fully defined rank-2 shape [10, 5].
  const PartialTensorShape base({10, 5});
  ASSERT_EQ(2, base.dims());
  EXPECT_EQ(10, base.dim_size(0));
  EXPECT_EQ(5, base.dim_size(1));
  EXPECT_EQ(50, base.num_elements());

  // Concatenating the shape with itself appends its dimensions:
  // [10, 5, 10, 5].
  const auto doubled = base.Concatenate(base);
  ASSERT_EQ(4, doubled.dims());
  EXPECT_EQ(10, doubled.dim_size(0));
  EXPECT_EQ(5, doubled.dim_size(1));
  EXPECT_EQ(10, doubled.dim_size(2));
  EXPECT_EQ(5, doubled.dim_size(3));
  EXPECT_EQ(50 * 50, doubled.num_elements());

  // Append single dimensions: first an unknown size (-1), then a zero size.
  const auto with_unknown = base.Concatenate(-1);
  const auto with_unknown_and_zero = with_unknown.Concatenate(0);
  ASSERT_EQ(3, with_unknown.dims());
  ASSERT_EQ(4, with_unknown_and_zero.dims());

  // Expected: [10, 5, -1]; the element count is reported as -1.
  EXPECT_EQ(10, with_unknown.dim_size(0));
  EXPECT_EQ(5, with_unknown.dim_size(1));
  EXPECT_EQ(-1, with_unknown.dim_size(2));
  EXPECT_EQ(-1, with_unknown.num_elements());

  // Expected: [10, 5, -1, 0]; the element count is still reported as -1.
  EXPECT_EQ(10, with_unknown_and_zero.dim_size(0));
  EXPECT_EQ(5, with_unknown_and_zero.dim_size(1));
  EXPECT_EQ(-1, with_unknown_and_zero.dim_size(2));
  EXPECT_EQ(0, with_unknown_and_zero.dim_size(3));
  EXPECT_EQ(-1, with_unknown_and_zero.num_elements());

  // Concatenating with a default-constructed (unknown-rank) shape gives a
  // result whose rank and element count are both reported as -1.
  const auto with_unknown_rank = base.Concatenate(PartialTensorShape());
  EXPECT_EQ(-1, with_unknown_rank.dims());
  EXPECT_EQ(-1, with_unknown_rank.num_elements());
}
|
|
TEST(PartialTensorShapeTest, InvalidShapeProto) {
  TensorShapeProto shape_proto;
  // Appends one dimension of the given size to the proto under test.
  const auto append_dim = [&shape_proto](int size) {
    shape_proto.add_dim()->set_size(size);
  };

  // An empty proto is valid.
  EXPECT_TRUE(PartialTensorShape::IsValid(shape_proto));

  // Positive dimension sizes are valid.
  append_dim(357);
  append_dim(982);
  EXPECT_TRUE(PartialTensorShape::IsValid(shape_proto));

  // A zero size and an unknown (-1) size are both valid.
  shape_proto.Clear();
  append_dim(0);
  append_dim(-1);
  EXPECT_TRUE(PartialTensorShape::IsValid(shape_proto));

  // Unknown rank with no dimensions listed is valid...
  shape_proto.Clear();
  shape_proto.set_unknown_rank(true);
  EXPECT_TRUE(PartialTensorShape::IsValid(shape_proto));

  // ...but unknown rank combined with an explicit dimension is not.
  append_dim(1);
  EXPECT_FALSE(PartialTensorShape::IsValid(shape_proto));

  // A dimension size of -2 is rejected.
  shape_proto.Clear();
  append_dim(-2);
  EXPECT_FALSE(PartialTensorShape::IsValid(shape_proto));
}
|
|
TEST(PartialTensorShapeTest, PartialShapeFullyDefined) {
  const PartialTensorShape one_unknown_dim({-1, 0, 1});
  const PartialTensorShape all_known_rank3({1, 0, 1});
  const PartialTensorShape two_unknown_dims({-1, -1, 1});
  const PartialTensorShape all_known_rank2({1, 0});
  const PartialTensorShape no_dims({});
  const PartialTensorShape default_constructed;

  // Any -1 dimension, or a default-constructed shape, is not fully defined.
  EXPECT_FALSE(one_unknown_dim.IsFullyDefined());
  EXPECT_FALSE(two_unknown_dims.IsFullyDefined());
  EXPECT_FALSE(default_constructed.IsFullyDefined());

  // Shapes without -1 entries (including the one with no dimensions) are
  // fully defined.
  EXPECT_TRUE(all_known_rank3.IsFullyDefined());
  EXPECT_TRUE(all_known_rank2.IsFullyDefined());
  EXPECT_TRUE(no_dims.IsFullyDefined());
}
|
|
TEST(PartialTensorShapeTest, ToTensorShape) {
  const PartialTensorShape no_dims({});
  const PartialTensorShape known_dims({1, 0});
  const PartialTensorShape has_unknown_dim({-1, 0});
  const PartialTensorShape default_constructed;
  TensorShape converted;

  // A shape built from an empty dimension list converts to rank 0.
  EXPECT_TRUE(no_dims.AsTensorShape(&converted));
  EXPECT_EQ(converted.dims(), 0);

  // A shape with only known sizes converts and keeps its dimensions.
  EXPECT_TRUE(known_dims.AsTensorShape(&converted));
  EXPECT_EQ(converted.dims(), 2);
  EXPECT_EQ(converted.dim_size(0), 1);
  EXPECT_EQ(converted.dim_size(1), 0);

  // Conversion fails for a -1 dimension and for a default-constructed shape.
  EXPECT_FALSE(has_unknown_dim.AsTensorShape(&converted));
  EXPECT_FALSE(default_constructed.AsTensorShape(&converted));
}
|
|
TEST(PartialTensorShapeTest, PartialShapeIdenticalTo) {
  // Seven pairwise-distinct shapes: different unknown-dimension patterns,
  // different ranks, a shape with no dimensions, and a default-constructed
  // shape.
  const PartialTensorShape a({-1, 0, 1});
  const PartialTensorShape b({1, 0, 1});
  const PartialTensorShape c({-1, -1, 1});
  const PartialTensorShape d({1, 0});
  const PartialTensorShape e({-1, 0, 2});
  const PartialTensorShape f({});
  const PartialTensorShape g;
  const std::vector<PartialTensorShape> shapes = {a, b, c, d, e, f, g};

  // Check every ordered pair: a shape must be identical to itself and to no
  // other entry. The previous inner bound (j < i) never reached i == j, so
  // self-identity was not tested at all; visiting both (i, j) and (j, i)
  // additionally checks that the relation is symmetric.
  for (size_t i = 0; i < shapes.size(); ++i) {
    for (size_t j = 0; j < shapes.size(); ++j) {
      EXPECT_EQ(i == j, shapes[i].IsIdenticalTo(shapes[j]))
          << "i = " << i << ", j = " << j;
    }
  }
}
|
|
TEST(PartialTensorShapeTest, PartialShapeCompatibleWith) {
  const PartialTensorShape unk_0_1({-1, 0, 1});
  const PartialTensorShape one_0_1({1, 0, 1});
  const PartialTensorShape unk_unk_1({-1, -1, 1});
  const PartialTensorShape one_0({1, 0});
  const PartialTensorShape unk_0_2({-1, 0, 2});
  const PartialTensorShape no_dims({});
  const PartialTensorShape default_constructed;

  // Every shape checked against itself is compatible.
  EXPECT_TRUE(no_dims.IsCompatibleWith(no_dims));
  EXPECT_TRUE(unk_0_1.IsCompatibleWith(unk_0_1));
  EXPECT_TRUE(one_0_1.IsCompatibleWith(one_0_1));

  // Three-dimensional shapes whose known sizes agree are compatible; a -1
  // entry does not conflict with any size.
  EXPECT_TRUE(unk_0_1.IsCompatibleWith(one_0_1));
  EXPECT_TRUE(unk_0_1.IsCompatibleWith(unk_unk_1));
  EXPECT_TRUE(one_0_1.IsCompatibleWith(unk_unk_1));

  // Three dimensions versus two: incompatible.
  EXPECT_FALSE(unk_0_1.IsCompatibleWith(one_0));
  EXPECT_FALSE(one_0_1.IsCompatibleWith(one_0));
  EXPECT_FALSE(unk_unk_1.IsCompatibleWith(one_0));

  // Known last dimension differs (1 vs. 2): incompatible.
  EXPECT_FALSE(unk_0_1.IsCompatibleWith(unk_0_2));
  EXPECT_FALSE(one_0_1.IsCompatibleWith(unk_0_2));
  EXPECT_FALSE(unk_unk_1.IsCompatibleWith(unk_0_2));

  // Three dimensions versus none: incompatible.
  EXPECT_FALSE(unk_0_1.IsCompatibleWith(no_dims));
  EXPECT_FALSE(one_0_1.IsCompatibleWith(no_dims));
  EXPECT_FALSE(unk_unk_1.IsCompatibleWith(no_dims));

  // A default-constructed shape is compatible in both directions and with
  // itself.
  EXPECT_TRUE(unk_0_1.IsCompatibleWith(default_constructed));
  EXPECT_TRUE(default_constructed.IsCompatibleWith(unk_0_1));
  EXPECT_TRUE(default_constructed.IsCompatibleWith(default_constructed));
}
|
|
TEST(PartialTensorShapeTest, ShapeCompatibleWith) {
  // Checks partial shapes against concrete TensorShape arguments.
  const PartialTensorShape partial({-1, 0, 1});
  const PartialTensorShape default_constructed;
  TensorShape two_dims({0, 1});
  TensorShape zero_first({0, 0, 1});
  TensorShape one_first({1, 0, 1});
  TensorShape middle_is_one({1, 1, 1});

  // [-1, 0, 1] accepts any first dimension, but the number of dimensions
  // and the remaining sizes must match.
  EXPECT_FALSE(partial.IsCompatibleWith(two_dims));
  EXPECT_TRUE(partial.IsCompatibleWith(zero_first));
  EXPECT_TRUE(partial.IsCompatibleWith(one_first));
  EXPECT_FALSE(partial.IsCompatibleWith(middle_is_one));

  // A default-constructed partial shape is compatible with every shape here.
  EXPECT_TRUE(default_constructed.IsCompatibleWith(two_dims));
  EXPECT_TRUE(default_constructed.IsCompatibleWith(zero_first));
  EXPECT_TRUE(default_constructed.IsCompatibleWith(one_first));
  EXPECT_TRUE(default_constructed.IsCompatibleWith(middle_is_one));
}
|
|
TEST(PartialTensorShapeTest, PartialShapeMergeWith) {
  // Exercises MergeWith on compatible and incompatible pairs. The result
  // object is reset to a default-constructed shape before each merge so that
  // no state carries over between cases.
  const PartialTensorShape a({-1, 0, 1});
  const PartialTensorShape b({1, 0, 1});
  const PartialTensorShape c({-1, -1, 1});
  const PartialTensorShape d({1, 0});
  const PartialTensorShape e({-1, 0, 2});
  const PartialTensorShape f({});
  const PartialTensorShape g;

  // Merging a shape with itself leaves it unchanged.
  PartialTensorShape test;
  TF_EXPECT_OK(a.MergeWith(a, &test));
  EXPECT_EQ(test.dims(), 3);
  EXPECT_EQ(test.dim_size(0), -1);
  EXPECT_EQ(test.dim_size(1), 0);
  EXPECT_EQ(test.dim_size(2), 1);

  // A known size in the other shape fills in an unknown dimension.
  test = PartialTensorShape();
  TF_EXPECT_OK(a.MergeWith(b, &test));
  EXPECT_EQ(test.dims(), 3);
  EXPECT_EQ(test.dim_size(0), 1);
  EXPECT_EQ(test.dim_size(1), 0);
  EXPECT_EQ(test.dim_size(2), 1);

  // Different number of dimensions (3 vs. 2) is an error.
  test = PartialTensorShape();
  EXPECT_TRUE(errors::IsInvalidArgument(a.MergeWith(d, &test)));

  // Same number of dimensions but conflicting known sizes in the last
  // dimension (1 vs. 2) is an error.
  test = PartialTensorShape();
  EXPECT_TRUE(errors::IsInvalidArgument(a.MergeWith(e, &test)));

  // Different number of dimensions (3 vs. none) is an error.
  test = PartialTensorShape();
  EXPECT_TRUE(errors::IsInvalidArgument(a.MergeWith(f, &test)));

  // Merging with a less specific shape keeps the known sizes, regardless of
  // which side is the receiver.
  test = PartialTensorShape();
  TF_EXPECT_OK(a.MergeWith(c, &test));
  EXPECT_EQ(test.dims(), 3);
  EXPECT_EQ(test.dim_size(0), -1);
  EXPECT_EQ(test.dim_size(1), 0);
  EXPECT_EQ(test.dim_size(2), 1);

  test = PartialTensorShape();
  TF_EXPECT_OK(c.MergeWith(a, &test));
  EXPECT_EQ(test.dims(), 3);
  EXPECT_EQ(test.dim_size(0), -1);
  EXPECT_EQ(test.dim_size(1), 0);
  EXPECT_EQ(test.dim_size(2), 1);

  // Merging with a default-constructed shape yields the other shape,
  // regardless of which side is the receiver.
  test = PartialTensorShape();
  TF_EXPECT_OK(a.MergeWith(g, &test));
  EXPECT_EQ(test.dims(), 3);
  EXPECT_EQ(test.dim_size(0), -1);
  EXPECT_EQ(test.dim_size(1), 0);
  EXPECT_EQ(test.dim_size(2), 1);

  test = PartialTensorShape();
  TF_EXPECT_OK(g.MergeWith(a, &test));
  EXPECT_EQ(test.dims(), 3);
  EXPECT_EQ(test.dim_size(0), -1);
  EXPECT_EQ(test.dim_size(1), 0);
  EXPECT_EQ(test.dim_size(2), 1);
}
|
|
TEST(PartialTensorShapeTest, MakePartialShapeEmpty) {
  // Building a shape from zero dimensions must produce a fully defined
  // shape, even though the target starts out not fully defined.
  const int64 no_dims[1] = {};
  PartialTensorShape result;
  EXPECT_FALSE(result.IsFullyDefined());

  TF_ASSERT_OK(PartialTensorShape::MakePartialShape(no_dims, 0, &result));
  EXPECT_TRUE(result.IsFullyDefined());
}
|
|
TEST(PartialTensorShapeTest, MakePartialShapeFull) {
  // Every entry of the input array, including an unknown (-1) size, must be
  // carried over into the resulting shape unchanged.
  const int kNumDims = 3;
  const int64 expected[kNumDims] = {7, -1, 2};
  PartialTensorShape result;
  TF_ASSERT_OK(
      PartialTensorShape::MakePartialShape(expected, kNumDims, &result));
  ASSERT_EQ(result.dims(), kNumDims);

  int axis = 0;
  while (axis < kNumDims) {
    EXPECT_EQ(result.dim_size(axis), expected[axis]);
    ++axis;
  }
}
|
|
TEST(PartialTensorShapeTest, MakePartialShapeInvalid) {
  // A dimension size of -2 must make MakePartialShape fail with
  // INVALID_ARGUMENT.
  const int64 bad_dims[3] = {7, -2, 2};
  PartialTensorShape result;
  const Status status =
      PartialTensorShape::MakePartialShape(bad_dims, 3, &result);
  EXPECT_EQ(error::INVALID_ARGUMENT, status.code());
}
|
|
} // namespace
|
} // namespace tensorflow
|