/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
|
Licensed under the Apache License, Version 2.0 (the "License");
|
you may not use this file except in compliance with the License.
|
You may obtain a copy of the License at
|
|
http://www.apache.org/licenses/LICENSE-2.0
|
|
Unless required by applicable law or agreed to in writing, software
|
distributed under the License is distributed on an "AS IS" BASIS,
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
See the License for the specific language governing permissions and
|
limitations under the License.
|
==============================================================================*/
|
|
#include "tensorflow/core/framework/graph_def_util.h"
|
|
#include "tensorflow/core/framework/function.h"
|
#include "tensorflow/core/framework/graph.pb.h"
|
#include "tensorflow/core/framework/node_def_builder.h"
|
#include "tensorflow/core/framework/op.h"
|
#include "tensorflow/core/framework/op_def.pb.h"
|
#include "tensorflow/core/framework/op_def_builder.h"
|
#include "tensorflow/core/lib/core/status_test_util.h"
|
#include "tensorflow/core/platform/test.h"
|
#include "tensorflow/core/util/equal_graph_def.h"
|
|
namespace tensorflow {
|
namespace {
|
|
// Runs `b.Finalize()` and copies the op definition it produced into
// `*op_def`. The copy happens regardless of the outcome, and the status
// returned by Finalize() is passed straight back to the caller.
Status FinalizeOpDef(const OpDefBuilder& b, OpDef* op_def) {
  OpRegistrationData registration;
  const Status finalize_status = b.Finalize(&registration);
  *op_def = registration.op_def;
  return finalize_status;
}
|
|
// Producer and consumer have default for an attr -> graph unchanged.
|
TEST(RemoveNewDefaultAttrsFromGraphDefTest, NoChangeWithDefault) {
|
OpList op_list;
|
TF_ASSERT_OK(
|
FinalizeOpDef(OpDefBuilder("NoChangeWithDefault").Attr("a: int = 12"),
|
op_list.add_op()));
|
OpListOpRegistry registry(&op_list);
|
|
GraphDef graph_def;
|
TF_ASSERT_OK(NodeDefBuilder("ncwd", "NoChangeWithDefault", ®istry)
|
.Finalize(graph_def.add_node()));
|
GraphDef expected_graph_def = graph_def;
|
|
std::set<std::pair<string, string>> op_attr_removed;
|
TF_ASSERT_OK(RemoveNewDefaultAttrsFromGraphDef(&graph_def, registry, registry,
|
&op_attr_removed));
|
|
TF_EXPECT_GRAPH_EQ(expected_graph_def, graph_def);
|
EXPECT_TRUE(op_attr_removed.empty());
|
}
|
|
// Producer and consumer both have an attr -> graph unchanged.
|
TEST(RemoveNewDefaultAttrsFromGraphDefTest, NoChangeNoDefault) {
|
OpList op_list;
|
TF_ASSERT_OK(FinalizeOpDef(OpDefBuilder("NoChangeNoDefault").Attr("a: int"),
|
op_list.add_op()));
|
OpListOpRegistry registry(&op_list);
|
|
GraphDef graph_def;
|
TF_ASSERT_OK(NodeDefBuilder("ncnd", "NoChangeNoDefault", ®istry)
|
.Attr("a", 42)
|
.Finalize(graph_def.add_node()));
|
GraphDef expected_graph_def = graph_def;
|
|
std::set<std::pair<string, string>> op_attr_removed;
|
TF_ASSERT_OK(RemoveNewDefaultAttrsFromGraphDef(&graph_def, registry, registry,
|
&op_attr_removed));
|
|
TF_EXPECT_GRAPH_EQ(expected_graph_def, graph_def);
|
EXPECT_TRUE(op_attr_removed.empty());
|
}
|
|
// Producer has default for an attr that the consumer does not know
|
// about, and the produced graph has the default value for the attr ->
|
// attr removed from graph (and so able to be consumed).
|
TEST(RemoveNewDefaultAttrsFromGraphDefTest, UsesDefault) {
|
OpList consumer_op_list;
|
TF_ASSERT_OK(
|
FinalizeOpDef(OpDefBuilder("UsesDefault"), consumer_op_list.add_op()));
|
OpListOpRegistry consumer_registry(&consumer_op_list);
|
|
OpList producer_op_list;
|
TF_ASSERT_OK(FinalizeOpDef(OpDefBuilder("UsesDefault").Attr("a: int = 17"),
|
producer_op_list.add_op()));
|
OpListOpRegistry producer_registry(&producer_op_list);
|
|
GraphDef produced_graph_def;
|
TF_ASSERT_OK(NodeDefBuilder("uses_default", "UsesDefault", &producer_registry)
|
.Finalize(produced_graph_def.add_node()));
|
|
std::set<std::pair<string, string>> op_attr_removed;
|
TF_ASSERT_OK(
|
RemoveNewDefaultAttrsFromGraphDef(&produced_graph_def, consumer_registry,
|
producer_registry, &op_attr_removed));
|
|
GraphDef expected_graph_def;
|
TF_ASSERT_OK(NodeDefBuilder("uses_default", "UsesDefault", &consumer_registry)
|
.Finalize(expected_graph_def.add_node()));
|
TF_EXPECT_GRAPH_EQ(expected_graph_def, produced_graph_def);
|
|
std::set<std::pair<string, string>> expected_removed({{"UsesDefault", "a"}});
|
EXPECT_EQ(expected_removed, op_attr_removed);
|
}
|
|
// Producer has default for an attr that the consumer does not know
|
// about, graph sets the attr to a value different from the default ->
|
// graph unchanged (but not able to be consumed by consumer).
|
TEST(RemoveNewDefaultAttrsFromGraphDefTest, ChangedFromDefault) {
|
OpList consumer_op_list;
|
TF_ASSERT_OK(FinalizeOpDef(OpDefBuilder("ChangedFromDefault"),
|
consumer_op_list.add_op()));
|
OpListOpRegistry consumer_registry(&consumer_op_list);
|
|
OpList producer_op_list;
|
TF_ASSERT_OK(
|
FinalizeOpDef(OpDefBuilder("ChangedFromDefault").Attr("a: int = 17"),
|
producer_op_list.add_op()));
|
OpListOpRegistry producer_registry(&producer_op_list);
|
|
GraphDef produced_graph_def;
|
TF_ASSERT_OK(NodeDefBuilder("changed_from_default", "ChangedFromDefault",
|
&producer_registry)
|
.Attr("a", 9)
|
.Finalize(produced_graph_def.add_node()));
|
GraphDef expected_graph_def = produced_graph_def;
|
|
std::set<std::pair<string, string>> op_attr_removed;
|
TF_ASSERT_OK(
|
RemoveNewDefaultAttrsFromGraphDef(&produced_graph_def, consumer_registry,
|
producer_registry, &op_attr_removed));
|
|
TF_EXPECT_GRAPH_EQ(expected_graph_def, produced_graph_def);
|
EXPECT_TRUE(op_attr_removed.empty());
|
}
|
|
// Attrs starting with underscores should not be removed.
|
TEST(RemoveNewDefaultAttrsFromGraphDefTest, UnderscoreAttrs) {
|
OpList consumer_op_list;
|
TF_ASSERT_OK(
|
FinalizeOpDef(OpDefBuilder("Underscore"), consumer_op_list.add_op()));
|
OpListOpRegistry consumer_registry(&consumer_op_list);
|
|
OpList producer_op_list;
|
TF_ASSERT_OK(
|
FinalizeOpDef(OpDefBuilder("Underscore"), producer_op_list.add_op()));
|
// Add the _underscore attr manually since OpDefBuilder would complain
|
OpDef::AttrDef* attr = producer_op_list.mutable_op(0)->add_attr();
|
attr->set_name("_underscore");
|
attr->set_type("int");
|
attr->mutable_default_value()->set_i(17);
|
OpListOpRegistry producer_registry(&producer_op_list);
|
|
GraphDef produced_graph_def;
|
TF_ASSERT_OK(NodeDefBuilder("node", "Underscore", &producer_registry)
|
.Attr("_underscore", 17)
|
.Finalize(produced_graph_def.add_node()));
|
GraphDef expected_graph_def = produced_graph_def;
|
|
std::set<std::pair<string, string>> op_attr_removed;
|
TF_ASSERT_OK(
|
RemoveNewDefaultAttrsFromGraphDef(&produced_graph_def, consumer_registry,
|
producer_registry, &op_attr_removed));
|
|
TF_EXPECT_GRAPH_EQ(expected_graph_def, produced_graph_def);
|
EXPECT_EQ(op_attr_removed.size(), 0);
|
}
|
|
// Default-valued attrs that are new to the consumer are also removed from
// nodes inside library functions. In `my_func`, UsesDefault's `a` is set to
// its default (17) and is removed; ChangedFromDefault's `a` is set to 99 and
// is kept.
TEST(RemoveNewDefaultAttrsFromGraphDefTest, HasFunction) {
  // Consumer knows both ops, but neither has any attrs.
  OpList consumer_ops;
  TF_ASSERT_OK(
      FinalizeOpDef(OpDefBuilder("UsesDefault"), consumer_ops.add_op()));
  TF_ASSERT_OK(FinalizeOpDef(OpDefBuilder("ChangedFromDefault"),
                             consumer_ops.add_op()));
  OpListOpRegistry consumer_registry(&consumer_ops);

  // Producer gives both ops an int attr `a` defaulting to 17.
  OpList producer_ops;
  TF_ASSERT_OK(FinalizeOpDef(OpDefBuilder("UsesDefault").Attr("a: int = 17"),
                             producer_ops.add_op()));
  TF_ASSERT_OK(
      FinalizeOpDef(OpDefBuilder("ChangedFromDefault").Attr("a: int = 17"),
                    producer_ops.add_op()));
  OpListOpRegistry producer_registry(&producer_ops);

  // The graph holds `my_func` in its library plus one node that calls it.
  GraphDef graph;
  *graph.mutable_library()->add_function() = FunctionDefHelper::Create(
      "my_func", {}, {}, {},
      {{{"x"}, "UsesDefault", {}, {{"a", 17}}},
       {{"y"}, "ChangedFromDefault", {}, {{"a", 99}}}},
      {});
  // Register the function's signature as an op so a call node can be built.
  OpList fn_ops;
  *fn_ops.add_op() = graph.library().function(0).signature();
  OpListOpRegistry fn_registry(&fn_ops);
  TF_ASSERT_OK(NodeDefBuilder("call_func", "my_func", &fn_registry)
                   .Finalize(graph.add_node()));

  std::set<std::pair<string, string>> removed;
  TF_ASSERT_OK(RemoveNewDefaultAttrsFromGraphDef(
      &graph, consumer_registry, producer_registry, &removed));

  // Expected: the same graph, except that node `x` no longer carries `a`.
  GraphDef expected;
  *expected.mutable_library()->add_function() = FunctionDefHelper::Create(
      "my_func", {}, {}, {},
      {{{"x"}, "UsesDefault", {}, {}},
       {{"y"}, "ChangedFromDefault", {}, {{"a", 99}}}},
      {});
  TF_ASSERT_OK(NodeDefBuilder("call_func", "my_func", &fn_registry)
                   .Finalize(expected.add_node()));
  TF_EXPECT_GRAPH_EQ(expected, graph);
  // The library is also compared textually.
  // NOTE(review): presumably because TF_EXPECT_GRAPH_EQ does not cover it;
  // confirm.
  EXPECT_EQ(expected.library().DebugString(), graph.library().DebugString());

  const std::set<std::pair<string, string>> expected_removed = {
      {"UsesDefault", "a"}};
  EXPECT_EQ(expected_removed, removed);
}
|
|
// StrippedOpListForGraph() keeps only the ops a graph uses, either directly
// or through a library function. It clears their summary and description but
// keeps is_commutative. OpsUsedByGraph() must report the same op set.
TEST(StrippedOpListForGraphTest, FlatTest) {
  // Four documented ops, A through D; only B is commutative.
  OpList op_list;
  for (const string& name : {"A", "B", "C", "D"}) {
    OpDef* def = op_list.add_op();
    def->set_name(name);
    def->set_summary("summary");
    def->set_description("description");
    def->set_is_commutative(name == "B");
  }

  // Every row uses B and C, one of them twice, in a different order. The
  // result must not depend on that ordering.
  const string graph_ops[4][3] = {
      {"C", "B", "B"}, {"B", "C", "B"}, {"B", "B", "C"}, {"C", "C", "B"}};
  for (const bool use_function : {false, true}) {
    for (const auto& ops : graph_ops) {
      GraphDef graph_def;
      if (!use_function) {
        // Put the ops straight into the graph as uniquely named nodes.
        for (const string& op : ops) {
          const string node_name =
              strings::StrCat("name", graph_def.node_size());
          NodeDef* node = graph_def.add_node();
          node->set_name(node_name);
          node->set_op(op);
        }
      } else {
        // Hide the ops inside function F and call F from the graph.
        FunctionDef* fn = graph_def.mutable_library()->add_function();
        fn->mutable_signature()->set_name("F");
        for (const string& op : ops) {
          fn->add_node_def()->set_op(op);
        }
        graph_def.add_node()->set_op("F");
      }

      OpList stripped_op_list;
      TF_ASSERT_OK(StrippedOpListForGraph(graph_def, OpListOpRegistry(&op_list),
                                          &stripped_op_list));

      // Exactly B then C survive, without their docs.
      ASSERT_EQ(stripped_op_list.op_size(), 2);
      const char* const kExpectedNames[] = {"B", "C"};
      for (int i = 0; i < 2; ++i) {
        const OpDef& op = stripped_op_list.op(i);
        EXPECT_EQ(op.name(), kExpectedNames[i]);
        EXPECT_EQ(op.summary(), "");
        EXPECT_EQ(op.description(), "");
        EXPECT_EQ(op.is_commutative(), i == 0);
      }

      // OpsUsedByGraph() must agree with the stripped list.
      std::set<string> used_ops;
      OpsUsedByGraph(graph_def, &used_ops);
      ASSERT_EQ(std::set<string>({"B", "C"}), used_ops);
    }
  }
}
|
|
// Ops reached only through nested function calls (graph -> C -> B -> A) are
// still found. Function names themselves do not appear in the result.
TEST(StrippedOpListForGraphTest, NestedFunctionTest) {
  // A is the only primitive op.
  OpList op_list;
  op_list.add_op()->set_name("A");

  for (const bool recursive : {false, true}) {
    GraphDef graph_def;
    FunctionDef* fn_b = graph_def.mutable_library()->add_function();
    FunctionDef* fn_c = graph_def.mutable_library()->add_function();
    fn_b->mutable_signature()->set_name("B");
    fn_c->mutable_signature()->set_name("C");
    // B calls A, and C calls B.
    fn_b->add_node_def()->set_op("A");
    fn_c->add_node_def()->set_op("B");
    if (recursive) {
      // Each function additionally calls itself; the result must not change.
      fn_b->add_node_def()->set_op("B");
      fn_c->add_node_def()->set_op("C");
    }

    // The graph itself only uses C.
    graph_def.add_node()->set_op("C");

    // Only the primitive op A may remain.
    OpList stripped_op_list;
    TF_ASSERT_OK(StrippedOpListForGraph(graph_def, OpListOpRegistry(&op_list),
                                        &stripped_op_list));
    ASSERT_EQ(stripped_op_list.op_size(), 1);
    ASSERT_EQ(stripped_op_list.op(0).name(), "A");

    // OpsUsedByGraph() must agree.
    std::set<string> used_ops;
    OpsUsedByGraph(graph_def, &used_ops);
    ASSERT_EQ(std::set<string>({"A"}), used_ops);
  }
}
|
|
} // namespace
|
} // namespace tensorflow
|